diff --git a/.github/codecov.yml b/.github/codecov.yml deleted file mode 100644 index 5f721427d7a..00000000000 --- a/.github/codecov.yml +++ /dev/null @@ -1,10 +0,0 @@ -# we measure coverage but don't enforce it -# https://docs.codecov.com/docs/codecov-yaml -coverage: - status: - patch: - default: - target: 0% - project: - default: - target: 0% diff --git a/.github/generate-codecov-yml.sh b/.github/generate-codecov-yml.sh new file mode 100755 index 00000000000..ddb60d0ce80 --- /dev/null +++ b/.github/generate-codecov-yml.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Run this from the repository root: +# +# .github/generate-codecov-yml.sh >> .github/codecov.yml + +cat <> $GITHUB_ENV - name: "Create Parsers badge" - uses: schneegans/dynamic-badges-action@v1.6.0 + uses: schneegans/dynamic-badges-action@v1.7.0 if: ${{ github.ref == 'refs/heads/master' && github.repository_owner == 'crowdsecurity' }} with: auth: ${{ secrets.GIST_BADGES_SECRET }} @@ -67,7 +64,7 @@ jobs: color: ${{ env.SCENARIO_BADGE_COLOR }} - name: "Create Scenarios badge" - uses: schneegans/dynamic-badges-action@v1.6.0 + uses: schneegans/dynamic-badges-action@v1.7.0 if: ${{ github.ref == 'refs/heads/master' && github.repository_owner == 'crowdsecurity' }} with: auth: ${{ secrets.GIST_BADGES_SECRET }} diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index 902c25ba329..211d856bc34 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -7,9 +7,6 @@ on: required: true type: string -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: name: "Functional tests" @@ -39,7 +36,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: "Install bats dependencies" env: @@ -58,7 +55,7 @@ jobs: MYSQL_USER: root - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: mysql MYSQL_HOST: 127.0.0.1 diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index e15f1e410c1..aec707f0c03 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -3,9 +3,6 @@ name: (sub) Bats / Postgres on: workflow_call: -env: - PREFIX_TEST_NAMES_WITH_FILE: true - jobs: build: name: "Functional tests" @@ -48,7 +45,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: "Install bats dependencies" env: @@ -67,7 +64,7 @@ jobs: PGUSER: postgres - name: "Run tests (DB_BACKEND: pgx)" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter env: DB_BACKEND: pgx PGHOST: 127.0.0.1 diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index 36194555e1d..a089aa53532 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -2,9 +2,11 @@ name: (sub) Bats / sqlite + coverage on: workflow_call: + secrets: + CODECOV_TOKEN: + required: true env: - PREFIX_TEST_NAMES_WITH_FILE: true TEST_COVERAGE: true jobs: @@ -29,7 +31,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: "Install bats dependencies" env: @@ -41,8 +43,12 @@ jobs: run: | make clean bats-build bats-fixture BUILD_STATIC=1 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: "Run tests" - run: make bats-test + run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter - name: "Collect coverage data" run: | @@ -77,8 +83,9 @@ jobs: run: for file in $(find ./test/local/var/log -type f); do echo ">>>>> $file"; cat $file; echo; done if: ${{ always() }} - - name: Upload crowdsec coverage to codecov - uses: codecov/codecov-action@v3 + - name: Upload bats coverage to codecov + uses: codecov/codecov-action@v4 with: files: ./coverage-bats.out flags: bats + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/bats.yml b/.github/workflows/bats.yml index 0ce8cf041ed..59976bad87d 100644 --- a/.github/workflows/bats.yml +++ b/.github/workflows/bats.yml @@ -28,6 +28,8 @@ on: jobs: sqlite: uses: ./.github/workflows/bats-sqlite-coverage.yml + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # Jobs for Postgres (and sometimes MySQL) can have failing tests on GitHub # CI, but they pass when run on devs' machines or in the release checks. We diff --git a/.github/workflows/ci-windows-build-msi.yml b/.github/workflows/ci-windows-build-msi.yml index 26c981143ad..a37aa43e2d0 100644 --- a/.github/workflows/ci-windows-build-msi.yml +++ b/.github/workflows/ci-windows-build-msi.yml @@ -35,12 +35,12 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: Build run: make windows_installer BUILD_RE2_WASM=1 - name: Upload MSI - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: crowdsec*msi name: crowdsec.msi diff --git a/.github/workflows/ci_release-drafter.yml b/.github/workflows/ci_release-drafter.yml index 2ccb6977cfd..0b8c9b386e6 100644 --- a/.github/workflows/ci_release-drafter.yml +++ b/.github/workflows/ci_release-drafter.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: # Drafts your next Release notes as Pull Requests are merged into "master" - - uses: release-drafter/release-drafter@v5 + - uses: release-drafter/release-drafter@v6 with: config-name: release-drafter.yml # (Optional) specify config name to use, relative to .github/. Default: release-drafter.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0904769dd60..2715c6590c3 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -49,9 +49,15 @@ jobs: # required to pick up tags for BUILD_VERSION fetch-depth: 0 + - name: "Set up Go" + uses: actions/setup-go@v5 + with: + go-version: "1.22" + cache-dependency-path: "**/go.sum" + # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -62,7 +68,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) # - name: Autobuild - # uses: github/codeql-action/autobuild@v2 + # uses: github/codeql-action/autobuild@v3 # ℹ️ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -71,14 +77,8 @@ jobs: # and modify them (or add more) to build your code if your project # uses a compiled language - - name: "Set up Go" - uses: actions/setup-go@v5 - with: - go-version: "1.21.6" - cache-dependency-path: "**/go.sum" - - run: | make clean build BUILD_RE2_WASM=1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 7bc63de0178..918f3bcaf1d 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -35,10 +35,10 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: - config: .github/buildkit.toml + buildkitd-config: .github/buildkit.toml - name: "Build image" - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile${{ matrix.flavor == 'debian' && '.debian' || '' }} @@ -50,26 +50,15 @@ jobs: cache-to: type=gha,mode=min - name: "Setup Python" - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - - - name: "Install pipenv" - run: | - cd docker/test - python -m pip install --upgrade pipenv wheel - - - name: "Cache virtualenvs" - id: cache-pipenv - uses: actions/cache@v3 - with: - path: ~/.local/share/virtualenvs - key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} + cache: 'pipenv' - name: "Install dependencies" - if: steps.cache-pipenv.outputs.cache-hit != 'true' run: | cd docker/test + python -m pip install --upgrade pipenv wheel pipenv install --deploy - name: "Create Docker network" diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 63781a7b25e..ba283f3890a 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -34,32 +34,33 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: Build run: | make build BUILD_RE2_WASM=1 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: Run tests run: | go install github.com/kyoh86/richgo@v0.3.10 - go test -coverprofile coverage.out -covermode=atomic ./... > out.txt + go test -tags expr_debug -coverprofile coverage.out -covermode=atomic ./... > out.txt if(!$?) { cat out.txt | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter; Exit 1 } cat out.txt | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter - name: Upload unit coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: files: coverage.out flags: unit-windows + token: ${{ secrets.CODECOV_TOKEN }} - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.55 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false - # the cache is already managed above, enabling it here - # gives errors when extracting - skip-pkg-cache: true - skip-build-cache: true diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index e8840c07f4e..3fdfb8a3e82 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -126,13 +126,40 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" + + - name: Run "go generate" and check for changes + run: | + set -e + # ensure the version of 'protoc' matches the one that generated the files + PROTOBUF_VERSION="21.12" + # don't pollute the repo + pushd $HOME + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip + unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d $HOME/.protoc + popd + export PATH="$HOME/.protoc/bin:$PATH" + go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + go generate ./... + protoc --version + if [[ $(git status --porcelain) ]]; then + echo "Error: Uncommitted changes found after running 'make generate'. Please commit all generated code." + git diff + exit 1 + else + echo "No changes detected after running 'make generate'." + fi - name: Create localstack streams run: | aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-1-shard --shard-count 1 aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-2-shards --shard-count 2 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: Build and run tests, static run: | sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential libre2-dev @@ -142,6 +169,11 @@ jobs: make build BUILD_STATIC=1 make go-acc | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter + # check if some component stubs are missing + - name: "Build profile: minimal" + run: | + make build BUILD_PROFILE=minimal + - name: Run tests again, dynamic run: | make clean build @@ -149,18 +181,15 @@ jobs: make go-acc | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter - name: Upload unit coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: files: coverage.out flags: unit-linux + token: ${{ secrets.CODECOV_TOKEN }} - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.55 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false - # the cache is already managed above, enabling it here - # gives errors when extracting - skip-pkg-cache: true - skip-build-cache: true diff --git a/.github/workflows/governance-bot.yaml b/.github/workflows/governance-bot.yaml index 5c08cabf5d1..c9e73e7811a 100644 --- a/.github/workflows/governance-bot.yaml +++ b/.github/workflows/governance-bot.yaml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest steps: # Semantic versioning, lock to different version: v2, v2.0 or a commit hash. - - uses: BirthdayResearch/oss-governance-bot@v3 + - uses: BirthdayResearch/oss-governance-bot@v4 with: # You can use a PAT to post a comment/label/status so that it shows up as a user instead of github-actions github-token: ${{secrets.GITHUB_TOKEN}} # optional, default to '${{ github.token }}' diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml index 005db0cc9d1..11b4401c6da 100644 --- a/.github/workflows/publish-docker.yml +++ b/.github/workflows/publish-docker.yml @@ -47,7 +47,7 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: - config: .github/buildkit.toml + buildkitd-config: .github/buildkit.toml - name: Login to DockerHub uses: docker/login-action@v3 @@ -93,7 +93,7 @@ jobs: - name: Build and push image (slim) if: ${{ inputs.slim }} - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile${{ inputs.debian && '.debian' || '' }} @@ -109,7 +109,7 @@ jobs: BUILD_VERSION=${{ inputs.crowdsec_version }} - name: Build and push image (full) - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile${{ inputs.debian && '.debian' || '' }} diff --git a/.github/workflows/publish-tarball-release.yml b/.github/workflows/publish-tarball-release.yml index 202882791e7..eeefb801719 100644 --- a/.github/workflows/publish-tarball-release.yml +++ b/.github/workflows/publish-tarball-release.yml @@ -25,7 +25,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.21.6" + go-version: "1.22" - name: Build the binaries run: | diff --git a/.gitignore b/.gitignore index 3054e9eb3c2..d76efcbfc48 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,10 @@ *.dylib *~ .pc + +# IDEs .vscode +.idea # If vendor is included, allow prebuilt (wasm?) libraries. !vendor/**/*.so @@ -34,7 +37,7 @@ test/coverage/* *.swo # Dependencies are not vendored by default, but a tarball is created by "make vendor" -# and provided in the release. Used by freebsd, gentoo, etc. +# and provided in the release. Used by gentoo, etc. vendor/ vendor.tgz @@ -57,3 +60,6 @@ msi __pycache__ *.py[cod] *.egg-info + +# automatically generated before running codecov +.github/codecov.yml diff --git a/.golangci.yml b/.golangci.yml index e1f2fc09a84..4909d3e60c0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,18 +1,6 @@ # https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -run: - skip-dirs: - - pkg/time/rate - skip-files: - - pkg/database/ent/generate.go - - pkg/yamlpatch/merge.go - - pkg/yamlpatch/merge_test.go - linters-settings: - cyclop: - # lower this after refactoring - max-complexity: 66 - gci: sections: - standard @@ -20,43 +8,29 @@ linters-settings: - prefix(github.com/crowdsecurity) - prefix(github.com/crowdsecurity/crowdsec) - gocognit: - # lower this after refactoring - min-complexity: 145 - - gocyclo: - # lower this after refactoring - min-complexity: 64 - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: -1 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: -1 + gomoddirectives: + replace-allow-list: + - golang.org/x/time/rate govet: - check-shadowing: true - - lll: - line-length: 140 + enable-all: true + disable: + - reflectvaluecompare + - fieldalignment maintidx: # raise this after refactoring - under: 9 + under: 15 misspell: locale: US nestif: # lower this after refactoring - min-complexity: 27 + min-complexity: 16 nlreturn: - block-size: 4 + block-size: 5 nolintlint: allow-unused: false # report any unused nolint directives @@ -68,10 +42,168 @@ linters-settings: depguard: rules: - main: + wrap: deny: - pkg: "github.com/pkg/errors" desc: "errors.Wrap() is deprecated in favor of fmt.Errorf()" + files: + - "!**/pkg/database/*.go" + yaml: + files: + - "!**/pkg/acquisition/acquisition.go" + - "!**/pkg/acquisition/acquisition_test.go" + - "!**/pkg/acquisition/modules/appsec/appsec.go" + - "!**/pkg/acquisition/modules/cloudwatch/cloudwatch.go" + - "!**/pkg/acquisition/modules/docker/docker.go" + - "!**/pkg/acquisition/modules/file/file.go" + - "!**/pkg/acquisition/modules/journalctl/journalctl.go" + - "!**/pkg/acquisition/modules/kafka/kafka.go" + - "!**/pkg/acquisition/modules/kinesis/kinesis.go" + - "!**/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go" + - "!**/pkg/acquisition/modules/loki/loki.go" + - "!**/pkg/acquisition/modules/loki/timestamp_test.go" + - "!**/pkg/acquisition/modules/s3/s3.go" + - "!**/pkg/acquisition/modules/syslog/syslog.go" + - "!**/pkg/acquisition/modules/wineventlog/wineventlog_windows.go" + - "!**/pkg/appsec/appsec.go" + - "!**/pkg/appsec/loader.go" + - "!**/pkg/csplugin/broker.go" + - "!**/pkg/leakybucket/buckets_test.go" + - "!**/pkg/leakybucket/manager_load.go" + - "!**/pkg/parser/node.go" + - "!**/pkg/parser/node_test.go" + - "!**/pkg/parser/parsing_test.go" + - "!**/pkg/parser/stage.go" + deny: + - pkg: "gopkg.in/yaml.v2" + desc: "yaml.v2 is deprecated for new code in favor of yaml.v3" + + stylecheck: + checks: + - all + - -ST1003 # should not use underscores in Go names; ... + - -ST1005 # error strings should not be capitalized + - -ST1012 # error var ... should have name of the form ErrFoo + - -ST1016 # methods on the same type should have the same receiver name + - -ST1022 # comment on exported var ... should be of the form ... + + revive: + ignore-generated-header: true + severity: error + enable-all-rules: true + rules: + - name: add-constant + disabled: true + - name: cognitive-complexity + # lower this after refactoring + arguments: [119] + - name: comment-spacings + disabled: true + - name: confusing-results + disabled: true + - name: cyclomatic + # lower this after refactoring + arguments: [39] + - name: defer + disabled: true + - name: empty-block + disabled: true + - name: empty-lines + disabled: true + - name: error-naming + disabled: true + - name: flag-parameter + disabled: true + - name: function-result-limit + arguments: [6] + - name: function-length + # lower this after refactoring + arguments: [110, 237] + - name: get-return + disabled: true + - name: increment-decrement + disabled: true + - name: import-alias-naming + disabled: true + - name: import-shadowing + disabled: true + - name: line-length-limit + # lower this after refactoring + arguments: [221] + - name: max-control-nesting + # lower this after refactoring + arguments: [7] + - name: max-public-structs + disabled: true + - name: nested-structs + disabled: true + - name: package-comments + disabled: true + - name: redundant-import-alias + disabled: true + - name: time-equal + disabled: true + - name: var-naming + disabled: true + - name: unchecked-type-assertion + disabled: true + - name: exported + disabled: true + - name: unexported-naming + disabled: true + - name: unexported-return + disabled: true + - name: unhandled-error + disabled: true + arguments: + - "fmt.Print" + - "fmt.Printf" + - "fmt.Println" + - name: unnecessary-stmt + disabled: true + - name: unused-parameter + disabled: true + - name: unused-receiver + disabled: true + - name: use-any + disabled: true + - name: useless-break + disabled: true + + wsl: + # Allow blocks to end with comments + allow-trailing-comment: true + + gocritic: + enable-all: true + disabled-checks: + - typeDefFirst + - paramTypeCombine + - httpNoBody + - ifElseChain + - importShadow + - hugeParam + - rangeValCopy + - commentedOutCode + - commentedOutImport + - unnamedResult + - sloppyReassign + - appendCombine + - captLocal + - typeUnparen + - commentFormatting + - deferInLoop # + - sprintfQuotedString # + - whyNoLint + - equalFold # + - unnecessaryBlock # + - ptrToRefParam # + - stringXbytes # + - appendAssign # + - tooManyResultsChecker + - unnecessaryDefer + - docStub + - preferFprint linters: enable-all: true @@ -79,93 +211,42 @@ linters: # # DEPRECATED by golangi-lint # - - deadcode # The owner seems to have abandoned the linter. Replaced by unused. - - exhaustivestruct # The owner seems to have abandoned the linter. Replaced by exhaustruct. - - golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - - ifshort # Checks that your code uses short syntax for if-statements whenever possible - - interfacer # Linter that suggests narrower interface types - - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - - nosnakecase # nosnakecase is a linter that detects snake case of variable naming and function name. - - scopelint # Scopelint checks for unpinned variables in go programs - - structcheck # The owner seems to have abandoned the linter. Replaced by unused. - - varcheck # The owner seems to have abandoned the linter. Replaced by unused. + - execinquery + - exportloopref + - gomnd # - # Enabled + # Redundant # - # - asasalint # check for pass []any as any in variadic func(...any) - # - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - # - bidichk # Checks for dangerous unicode character sequences - # - bodyclose # checks whether HTTP response body is closed successfully - # - cyclop # checks function and package cyclomatic complexity - # - decorder # check declaration order and count of types, constants, variables and functions - # - depguard # Go linter that checks if package imports are in a list of acceptable packages - # - dupword # checks for duplicate words in the source code - # - durationcheck # check for two durations multiplied together - # - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - # - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - # - exportloopref # checks for pointers to enclosing loop variables - # - funlen # Tool for detection of long functions - # - ginkgolinter # enforces standards of using ginkgo and gomega - # - gochecknoinits # Checks that no init functions are present in Go code - # - gocognit # Computes and checks the cognitive complexity of functions - # - gocritic # Provides diagnostics that check for bugs, performance and style issues. - # - gocyclo # Computes and checks the cyclomatic complexity of functions - # - goheader # Checks is file header matches to pattern - # - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - # - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - # - goprintffuncname # Checks that printf-like functions are named with `f` at the end - # - gosimple # (megacheck): Linter for Go source code that specializes in simplifying a code - # - govet # (vet, vetshadow): Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - # - grouper # An analyzer to analyze expression groups. - # - importas # Enforces consistent import aliases - # - ineffassign # Detects when assignments to existing variables are not used - # - interfacebloat # A linter that checks the number of methods inside an interface. - # - logrlint # Check logr arguments. - # - maintidx # maintidx measures the maintainability index of each function. - # - makezero # Finds slice declarations with non-zero initial length - # - misspell # Finds commonly misspelled English words in comments - # - nakedret # Finds naked returns in functions greater than a specified function length - # - nestif # Reports deeply nested if statements - # - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - # - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - # - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - # - predeclared # find code that shadows one of Go's predeclared identifiers - # - reassign # Checks that package variables are not reassigned - # - rowserrcheck # checks whether Err of rows is checked successfully - # - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - # - staticcheck # (megacheck): Staticcheck is a go vet on steroids, applying a ton of static analysis checks - # - testableexamples # linter checks if examples are testable (have an expected output) - # - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - # - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - # - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - # - unconvert # Remove unnecessary type conversions - # - unused # (megacheck): Checks Go code for unused constants, variables, functions and types - # - usestdlibvars # A linter that detect the possibility to use variables/constants from the Go standard library. - # - wastedassign # wastedassign finds wasted assignment statements. + - gocyclo # revive + - cyclop # revive + - lll # revive + - funlen # revive + - gocognit # revive + + # Disabled atm + + - intrange # intrange is a linter to find places where for loops could make use of an integer range. # # Recommended? (easy) # - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and reports occasions, where the check for the returned error can be omitted. - exhaustive # check exhaustiveness of enum switch statements - gci # Gci control golang package import order and make it always deterministic. - godot # Check if comments end in a period - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. + - goimports # Check import statements are formatted according to the 'goimport' command. Reformat imports in autofix mode. - gosec # (gas): Inspects source code for security problems - inamedparam # reports interfaces with unnamed method parameters - - lll # Reports long lines - musttag # enforce field tags in (un)marshaled structs - promlinter # Check Prometheus metrics naming via promlint - protogetter # Reports direct reads from proto message fields when getters should be used - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - tagalign # check that struct tags are well aligned - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - thelper # thelper detects tests helpers which is not start with t.Helper() method. - wrapcheck # Checks that errors returned from external packages are wrapped # @@ -173,12 +254,12 @@ linters: # - containedctx # containedctx is a linter that detects struct contained context.Context field - - contextcheck # check the function whether use a non-inherited context + - contextcheck # check whether the function uses a non-inherited context - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - - noctx # noctx finds sending http request without context.Context + - noctx # Finds sending http request without context.Context - unparam # Reports unused function parameters # @@ -187,8 +268,8 @@ linters: - gofumpt # Gofumpt checks whether code was gofumpt-ed. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - - whitespace # Tool for detection of leading and trailing whitespace - - wsl # Whitespace Linter - Forces you to use empty lines! + - whitespace # Whitespace is a linter that checks for unnecessary newlines at the start and end of functions, if, for, etc. + - wsl # add or remove empty lines # # Well intended, but not ready for this @@ -196,19 +277,17 @@ linters: - dupl # Tool for code clone detection - forcetypeassert # finds forced type assertions - godox # Tool for detection of FIXME, TODO and other comment keywords - - goerr113 # Golang linter to check the errors handling expressions - - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - err113 # Go linter to check the errors handling expressions + - paralleltest # Detects missing usage of t.Parallel() method in your Go test - testpackage # linter that makes you use a separate _test package # # Too strict / too many false positives (for now?) # - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustruct # Checks if all structure fields are initialized - forbidigo # Forbids identifiers - - gochecknoglobals # check that no global variables exist + - gochecknoglobals # Check that no global variables exist. - goconst # Finds repeated strings that could be replaced by a constant - - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - varnamelen # checks that the length of a variable's name matches its scope @@ -223,40 +302,31 @@ issues: # “Look, that’s why there’s rules, understand? So that you think before you # break ‘em.” ― Terry Pratchett + exclude-dirs: + - pkg/time/rate + - pkg/metabase + + exclude-files: + - pkg/yamlpatch/merge.go + - pkg/yamlpatch/merge_test.go + + exclude-generated: strict + max-issues-per-linter: 0 max-same-issues: 0 exclude-rules: # Won't fix: - - path: go.mod - text: "replacement are not allowed: golang.org/x/time/rate" - # `err` is often shadowed, we may continue to do it - linters: - govet - text: "shadow: declaration of \"err\" shadows declaration" + text: "shadow: declaration of \"(err|ctx)\" shadows declaration" - linters: - errcheck text: "Error return value of `.*` is not checked" - - linters: - - gocritic - text: "ifElseChain: rewrite if-else to switch statement" - - - linters: - - gocritic - text: "captLocal: `.*' should not be capitalized" - - - linters: - - gocritic - text: "appendAssign: append result not assigned to the same slice" - - - linters: - - gocritic - text: "commentFormatting: put a space between `//` and comment text" - # Will fix, trivial - just beware of merge conflicts - linters: @@ -279,18 +349,10 @@ issues: - errorlint text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors" - - linters: - - errorlint - text: "type assertion on error will fail on wrapped errors. Use errors.Is to check for specific errors" - - linters: - errorlint text: "comparing with .* will fail on wrapped errors. Use errors.Is to check for a specific error" - - linters: - - errorlint - text: "switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors" - - linters: - nosprintfhostport text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf" @@ -306,16 +368,96 @@ issues: - nonamedreturns text: "named return .* with type .* found" - # - # Will fix, might be trickier - # + - linters: + - revive + path: pkg/leakybucket/manager_load.go + text: "confusing-naming: Field '.*' differs only by capitalization to other field in the struct type BucketFactory" - linters: - - staticcheck - text: "x509.ParseCRL has been deprecated since Go 1.19: Use ParseRevocationList instead" + - revive + path: pkg/exprhelpers/helpers.go + text: "confusing-naming: Method 'flatten' differs only by capitalization to function 'Flatten' in the same source file" - # https://github.com/pkg/errors/issues/245 - linters: - - depguard - text: "import 'github.com/pkg/errors' is not allowed .*" + - revive + path: pkg/appsec/query_utils.go + text: "confusing-naming: Method 'parseQuery' differs only by capitalization to function 'ParseQuery' in the same source file" + - linters: + - revive + path: pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go + text: "confusing-naming: Method 'QueryRange' differs only by capitalization to method 'queryRange' in the same source file" + + - linters: + - revive + path: cmd/crowdsec-cli/copyfile.go + + - linters: + - revive + path: pkg/hubtest/hubtest_item.go + text: "cyclomatic: .*RunWithLogFile" + + # tolerate complex functions in tests for now + - linters: + - maintidx + path: "(.+)_test.go" + + # tolerate long functions in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "function-length: .*" + + # tolerate long lines in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "line-length-limit: .*" + + # tolerate deep exit in tests, for now + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "deep-exit: .*" + + # we use t,ctx instead of ctx,t in tests + - linters: + - revive + path: "pkg/(.+)_test.go" + text: "context-as-argument: context.Context should be the first parameter of a function" + + # tolerate deep exit in cobra's OnInitialize, for now + - linters: + - revive + path: "cmd/crowdsec-cli/main.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec-cli/clihub/item_metrics.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec-cli/idgen/password.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "pkg/leakybucket/overflows.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec/crowdsec.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec/api.go" + text: "deep-exit: .*" + + - linters: + - revive + path: "cmd/crowdsec/win_service.go" + text: "deep-exit: .*" diff --git a/Dockerfile b/Dockerfile index 2369c09dfa6..450ea69017f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.21.6-alpine3.18 AS build +FROM golang:1.22-alpine3.20 AS build ARG BUILD_VERSION @@ -11,27 +11,27 @@ ENV BUILD_VERSION=${BUILD_VERSION} # wizard.sh requires GNU coreutils RUN apk add --no-cache git g++ gcc libc-dev make bash gettext binutils-gold coreutils pkgconfig && \ - wget https://github.com/google/re2/archive/refs/tags/${RE2_VERSION}.tar.gz && \ + wget -q https://github.com/google/re2/archive/refs/tags/${RE2_VERSION}.tar.gz && \ tar -xzf ${RE2_VERSION}.tar.gz && \ cd re2-${RE2_VERSION} && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.40.4 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . -RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \ +RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 CGO_CFLAGS="-D_LARGEFILE64_SOURCE" && \ cd crowdsec-v* && \ ./wizard.sh --docker-mode && \ cd - >/dev/null && \ - cscli hub update && \ + cscli hub update --with-content && \ cscli collections install crowdsecurity/linux && \ cscli parsers install crowdsecurity/whitelists # In case we need to remove agents here.. # cscli machines list -o json | yq '.[].machineId' | xargs -r cscli machines delete -FROM alpine:latest as slim +FROM alpine:latest AS slim RUN apk add --no-cache --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community tzdata bash rsync && \ mkdir -p /staging/etc/crowdsec && \ @@ -43,11 +43,12 @@ COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/l COPY --from=build /etc/crowdsec /staging/etc/crowdsec COPY --from=build /go/src/crowdsec/docker/docker_start.sh / COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml +COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml -ENTRYPOINT /bin/bash /docker_start.sh +ENTRYPOINT ["/bin/bash", "/docker_start.sh"] -FROM slim as plugins +FROM slim AS full # Due to the wizard using cp -n, we have to copy the config files directly from the source as -n does not exist in busybox cp # The files are here for reference, as users will need to mount a new version to be actually able to use notifications @@ -60,11 +61,3 @@ COPY --from=build \ /staging/etc/crowdsec/notifications/ COPY --from=build /usr/local/lib/crowdsec/plugins /usr/local/lib/crowdsec/plugins - -FROM slim as geoip - -COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec - -FROM plugins as full - -COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec diff --git a/Dockerfile.debian b/Dockerfile.debian index ba0cd20fb43..8bf2698c786 100644 --- a/Dockerfile.debian +++ b/Dockerfile.debian @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.21.6-bookworm AS build +FROM golang:1.22-bookworm AS build ARG BUILD_VERSION @@ -21,7 +21,7 @@ RUN apt-get update && \ make && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.40.4 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . @@ -29,14 +29,14 @@ RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \ cd crowdsec-v* && \ ./wizard.sh --docker-mode && \ cd - >/dev/null && \ - cscli hub update && \ + cscli hub update --with-content && \ cscli collections install crowdsecurity/linux && \ cscli parsers install crowdsecurity/whitelists # In case we need to remove agents here.. # cscli machines list -o json | yq '.[].machineId' | xargs -r cscli machines delete -FROM debian:bookworm-slim as slim +FROM debian:bookworm-slim AS slim ENV DEBIAN_FRONTEND=noninteractive ENV DEBCONF_NOWARNINGS="yes" @@ -62,9 +62,9 @@ COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/conf RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml && \ yq eval -i ".plugin_config.group = \"nogroup\"" /staging/etc/crowdsec/config.yaml -ENTRYPOINT /bin/bash docker_start.sh +ENTRYPOINT ["/bin/bash", "docker_start.sh"] -FROM slim as plugins +FROM slim AS plugins # Due to the wizard using cp -n, we have to copy the config files directly from the source as -n does not exist in busybox cp # The files are here for reference, as users will need to mount a new version to be actually able to use notifications @@ -78,10 +78,10 @@ COPY --from=build \ COPY --from=build /usr/local/lib/crowdsec/plugins /usr/local/lib/crowdsec/plugins -FROM slim as geoip +FROM slim AS geoip COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec -FROM plugins as full +FROM plugins AS full COPY --from=build /var/lib/crowdsec /staging/var/lib/crowdsec diff --git a/Makefile b/Makefile index 5d656165fa8..bbfa4bbee94 100644 --- a/Makefile +++ b/Makefile @@ -25,10 +25,6 @@ BUILD_STATIC ?= 0 # List of plugins to build PLUGINS ?= $(patsubst ./cmd/notification-%,%,$(wildcard ./cmd/notification-*)) -# Can be overriden, if you can deal with the consequences -BUILD_REQUIRE_GO_MAJOR ?= 1 -BUILD_REQUIRE_GO_MINOR ?= 21 - #-------------------------------------- GO = go @@ -78,10 +74,11 @@ LD_OPTS_VARS= \ -X '$(GO_MODULE_NAME)/pkg/csconfig.defaultDataDir=$(DEFAULT_DATADIR)' ifneq (,$(DOCKER_BUILD)) -LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.System=docker' +LD_OPTS_VARS += -X 'github.com/crowdsecurity/go-cs-lib/version.System=docker' endif -GO_TAGS := netgo,osusergo,sqlite_omit_load_extension +#expr_debug tag is required to enable the debug mode in expr +GO_TAGS := netgo,osusergo,sqlite_omit_load_extension,expr_debug # this will be used by Go in the make target, some distributions require it export PKG_CONFIG_PATH:=/usr/local/lib/pkgconfig:$(PKG_CONFIG_PATH) @@ -118,6 +115,69 @@ STRIP_SYMBOLS := -s -w DISABLE_OPTIMIZATION := endif +#-------------------------------------- + +# Handle optional components and build profiles, to save space on the final binaries. + +# Keep it safe for now until we decide how to expand on the idea. Either choose a profile or exclude components manually. +# For example if we want to disable some component by default, or have opt-in components (INCLUDE?). + +ifeq ($(and $(BUILD_PROFILE),$(EXCLUDE)),1) +$(error "Cannot specify both BUILD_PROFILE and EXCLUDE") +endif + +COMPONENTS := \ + datasource_appsec \ + datasource_cloudwatch \ + datasource_docker \ + datasource_file \ + datasource_k8saudit \ + datasource_kafka \ + datasource_journalctl \ + datasource_kinesis \ + datasource_loki \ + datasource_s3 \ + datasource_syslog \ + datasource_wineventlog \ + cscli_setup + +comma := , +space := $(empty) $(empty) + +# Predefined profiles + +# keep only datasource-file +EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,,$(COMPONENTS))) + +# example +# EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3 + +BUILD_PROFILE ?= default + +# Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set +ifeq ($(BUILD_PROFILE),minimal) +EXCLUDE ?= $(EXCLUDE_MINIMAL) +else ifneq ($(BUILD_PROFILE),default) +$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default) +endif + +# Create list of excluded components from the EXCLUDE variable +EXCLUDE_LIST := $(subst $(comma),$(space),$(EXCLUDE)) + +INVALID_COMPONENTS := $(filter-out $(COMPONENTS),$(EXCLUDE_LIST)) +ifneq ($(INVALID_COMPONENTS),) +$(error Invalid optional components specified in EXCLUDE: $(INVALID_COMPONENTS). Valid components are: $(COMPONENTS)) +endif + +# Convert the excluded components to "no_" form +COMPONENT_TAGS := $(foreach component,$(EXCLUDE_LIST),no_$(component)) + +ifneq ($(COMPONENT_TAGS),) +GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS)) +endif + +#-------------------------------------- + export LD_OPTS=-ldflags "$(STRIP_SYMBOLS) $(EXTLDFLAGS) $(LD_OPTS_VARS)" \ -trimpath -tags $(GO_TAGS) $(DISABLE_OPTIMIZATION) @@ -128,11 +188,12 @@ endif #-------------------------------------- .PHONY: build -build: pre-build goversion crowdsec cscli plugins ## Build crowdsec, cscli and plugins +build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins -.PHONY: pre-build -pre-build: ## Sanity checks and build information +.PHONY: build-info +build-info: ## Print build information $(info Building $(BUILD_VERSION) ($(BUILD_TAG)) $(BUILD_TYPE) for $(GOOS)/$(GOARCH)) + $(info Excluded components: $(EXCLUDE_LIST)) ifneq (,$(RE2_FAIL)) $(error $(RE2_FAIL)) @@ -195,14 +256,13 @@ clean: clean-debian clean-rpm testclean ## Remove build artifacts ) .PHONY: cscli -cscli: goversion ## Build cscli +cscli: ## Build cscli @$(MAKE) -C $(CSCLI_FOLDER) build $(MAKE_FLAGS) .PHONY: crowdsec -crowdsec: goversion ## Build crowdsec +crowdsec: ## Build crowdsec @$(MAKE) -C $(CROWDSEC_FOLDER) build $(MAKE_FLAGS) - .PHONY: testclean testclean: bats-clean ## Remove test artifacts @$(RM) pkg/apiserver/ent $(WIN_IGNORE_ERR) @@ -216,24 +276,29 @@ export AWS_ACCESS_KEY_ID=test export AWS_SECRET_ACCESS_KEY=test testenv: - @echo 'NOTE: You need Docker, docker-compose and run "make localstack" in a separate shell ("make localstack-stop" to terminate it)' + @echo 'NOTE: You need to run "make localstack" in a separate shell, "make localstack-stop" to terminate it' .PHONY: test -test: testenv goversion ## Run unit tests with localstack - $(GOTEST) $(LD_OPTS) ./... +test: testenv ## Run unit tests with localstack + $(GOTEST) --tags=$(GO_TAGS) $(LD_OPTS) ./... .PHONY: go-acc -go-acc: testenv goversion ## Run unit tests with localstack + coverage - go-acc ./... -o coverage.out --ignore database,notifications,protobufs,cwversion,cstest,models -- $(LD_OPTS) +go-acc: testenv ## Run unit tests with localstack + coverage + go-acc ./... -o coverage.out --ignore database,notifications,protobufs,cwversion,cstest,models --tags $(GO_TAGS) -- $(LD_OPTS) + +check_docker: + @if ! docker info > /dev/null 2>&1; then \ + echo "Could not run 'docker info': check that docker is running, and if you need to run this command with sudo."; \ + fi # mock AWS services .PHONY: localstack -localstack: ## Run localstack containers (required for unit testing) - docker-compose -f test/localstack/docker-compose.yml up +localstack: check_docker ## Run localstack containers (required for unit testing) + docker compose -f test/localstack/docker-compose.yml up .PHONY: localstack-stop -localstack-stop: ## Stop localstack containers - docker-compose -f test/localstack/docker-compose.yml down +localstack-stop: check_docker ## Stop localstack containers + docker compose -f test/localstack/docker-compose.yml down # build vendor.tgz to be distributed with the release .PHONY: vendor @@ -296,5 +361,4 @@ else include test/bats.mk endif -include mk/goversion.mk include mk/help.mk diff --git a/README.md b/README.md index 6428c3a8053..a900f0ee514 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ +Go Reference diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 82caba42bae..6051ca67393 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -15,19 +15,13 @@ pool: stages: - stage: Build jobs: - - job: + - job: Build displayName: "Build" steps: - - task: DotNetCoreCLI@2 - displayName: "Install SignClient" - inputs: - command: 'custom' - custom: 'tool' - arguments: 'install --global SignClient --version 1.3.155' - task: GoTool@0 displayName: "Install Go" inputs: - version: '1.21.6' + version: '1.22' - pwsh: | choco install -y make @@ -39,24 +33,14 @@ stages: #we are not calling make windows_installer because we want to sign the binaries before they are added to the MSI script: | make build BUILD_RE2_WASM=1 - - task: AzureKeyVault@2 - inputs: - azureSubscription: 'Azure subscription 1(8a93ab40-7e99-445e-ad47-0f6a3e2ef546)' - KeyVaultName: 'CodeSigningSecrets' - SecretsFilter: 'CodeSigningUser,CodeSigningPassword' - RunAsPreJob: false - - - task: DownloadSecureFile@1 - inputs: - secureFile: appsettings.json - - - pwsh: | - SignClient.exe Sign --name "crowdsec-binaries" ` - --input "**/*.exe" --config (Join-Path -Path $(Agent.TempDirectory) -ChildPath "appsettings.json") ` - --user $(CodeSigningUser) --secret '$(CodeSigningPassword)' - displayName: "Sign Crowdsec binaries + plugins" + - pwsh: | $build_version=$env:BUILD_SOURCEBRANCHNAME + #Override the version if it's set in the pipeline + if ( ${env:USERBUILDVERSION} -ne "") + { + $build_version = ${env:USERBUILDVERSION} + } if ($build_version.StartsWith("v")) { $build_version = $build_version.Substring(1) @@ -69,35 +53,112 @@ stages: displayName: GetCrowdsecVersion name: GetCrowdsecVersion - pwsh: | - .\make_installer.ps1 -version '$(GetCrowdsecVersion.BuildVersion)' + Get-ChildItem -Path .\cmd -Directory | ForEach-Object { + $dirName = $_.Name + Get-ChildItem -Path .\cmd\$dirName -File -Filter '*.exe' | ForEach-Object { + $fileName = $_.Name + $destDir = Join-Path $(Build.ArtifactStagingDirectory) cmd\$dirName + New-Item -ItemType Directory -Path $destDir -Force + Copy-Item -Path .\cmd\$dirName\$fileName -Destination $destDir + } + } + displayName: "Copy binaries to staging directory" + - task: PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)' + artifact: 'unsigned_binaries' + displayName: "Upload binaries artifact" + + - stage: Sign + dependsOn: Build + variables: + - group: 'FOSS Build Variables' + - name: BuildVersion + value: $[ stageDependencies.Build.Build.outputs['GetCrowdsecVersion.BuildVersion'] ] + condition: succeeded() + jobs: + - job: Sign + displayName: "Sign" + steps: + - download: current + artifact: unsigned_binaries + displayName: "Download binaries artifact" + - task: CopyFiles@2 + inputs: + SourceFolder: '$(Pipeline.Workspace)/unsigned_binaries' + TargetFolder: '$(Build.SourcesDirectory)' + displayName: "Copy binaries to workspace" + - task: DotNetCoreCLI@2 + displayName: "Install SignTool tool" + inputs: + command: 'custom' + custom: 'tool' + arguments: install --global sign --version 0.9.0-beta.23127.3 + - task: AzureKeyVault@2 + displayName: "Get signing parameters" + inputs: + azureSubscription: "Azure subscription" + KeyVaultName: "$(KeyVaultName)" + SecretsFilter: "TenantId,ClientId,ClientSecret,Certificate,KeyVaultUrl" + - pwsh: | + sign code azure-key-vault ` + "**/*.exe" ` + --base-directory "$(Build.SourcesDirectory)/cmd/" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign crowdsec binaries" + - pwsh: | + .\make_installer.ps1 -version '$(BuildVersion)' displayName: "Build Crowdsec MSI" name: BuildMSI - - pwsh: | - .\make_chocolatey.ps1 -version '$(GetCrowdsecVersion.BuildVersion)' + .\make_chocolatey.ps1 -version '$(BuildVersion)' displayName: "Build Chocolatey nupkg" - - pwsh: | - SignClient.exe Sign --name "crowdsec-msi" ` - --input "*.msi" --config (Join-Path -Path $(Agent.TempDirectory) -ChildPath "appsettings.json") ` - --user $(CodeSigningUser) --secret '$(CodeSigningPassword)' - displayName: "Sign Crowdsec MSI" - - - task: PublishBuildArtifacts@1 + sign code azure-key-vault ` + "*.msi" ` + --base-directory "$(Build.SourcesDirectory)" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign MSI package" + - pwsh: | + sign code azure-key-vault ` + "*.nupkg" ` + --base-directory "$(Build.SourcesDirectory)" ` + --publisher-name "CrowdSec" ` + --description "CrowdSec" ` + --description-url "https://github.com/crowdsecurity/crowdsec" ` + --azure-key-vault-tenant-id "$(TenantId)" ` + --azure-key-vault-client-id "$(ClientId)" ` + --azure-key-vault-client-secret "$(ClientSecret)" ` + --azure-key-vault-certificate "$(Certificate)" ` + --azure-key-vault-url "$(KeyVaultUrl)" + displayName: "Sign nuget package" + - task: PublishPipelineArtifact@1 inputs: - PathtoPublish: '$(Build.Repository.LocalPath)\\crowdsec_$(GetCrowdsecVersion.BuildVersion).msi' - ArtifactName: 'crowdsec.msi' - publishLocation: 'Container' - displayName: "Upload MSI artifact" - - - task: PublishBuildArtifacts@1 + targetPath: '$(Build.SourcesDirectory)/crowdsec_$(BuildVersion).msi' + artifact: 'signed_msi_package' + displayName: "Upload signed MSI artifact" + - task: PublishPipelineArtifact@1 inputs: - PathtoPublish: '$(Build.Repository.LocalPath)\\windows\\Chocolatey\\crowdsec\\crowdsec.$(GetCrowdsecVersion.BuildVersion).nupkg' - ArtifactName: 'crowdsec.nupkg' - publishLocation: 'Container' - displayName: "Upload nupkg artifact" + targetPath: '$(Build.SourcesDirectory)/crowdsec.$(BuildVersion).nupkg' + artifact: 'signed_nuget_package' + displayName: "Upload signed nuget artifact" + - stage: Publish - dependsOn: Build + dependsOn: Sign jobs: - deployment: "Publish" displayName: "Publish to GitHub" @@ -119,8 +180,7 @@ stages: assetUploadMode: 'replace' addChangeLog: false isPreRelease: true #we force prerelease because the pipeline is invoked on tag creation, which happens when we do a prerelease - #the .. is an ugly hack, but I can't find the var that gives D:\a\1 ... assets: | - $(Build.ArtifactStagingDirectory)\..\crowdsec.msi/*.msi - $(Build.ArtifactStagingDirectory)\..\crowdsec.nupkg/*.nupkg + $(Pipeline.Workspace)/signed_msi_package/*.msi + $(Pipeline.Workspace)/signed_nuget_package/*.nupkg condition: ne(variables['GetLatestPrelease.LatestPreRelease'], '') diff --git a/cmd/crowdsec-cli/Makefile b/cmd/crowdsec-cli/Makefile index 392361ef82e..6d6e4da8dbd 100644 --- a/cmd/crowdsec-cli/Makefile +++ b/cmd/crowdsec-cli/Makefile @@ -8,8 +8,6 @@ GO = go GOBUILD = $(GO) build BINARY_NAME = cscli$(EXT) -PREFIX ?= "/" -BIN_PREFIX = $(PREFIX)"/usr/local/bin/" .PHONY: all all: clean build @@ -17,17 +15,5 @@ all: clean build build: clean $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) -.PHONY: install -install: install-conf install-bin - -install-conf: - -install-bin: - @install -v -m 755 -D "$(BINARY_NAME)" "$(BIN_PREFIX)/$(BINARY_NAME)" || exit - -uninstall: - @$(RM) $(CSCLI_CONFIG) $(WIN_IGNORE_ERR) - @$(RM) $(BIN_PREFIX)$(BINARY_NAME) $(WIN_IGNORE_ERR) - clean: @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/cmd/crowdsec-cli/alerts.go b/cmd/crowdsec-cli/alerts.go deleted file mode 100644 index 15824d2d067..00000000000 --- a/cmd/crowdsec-cli/alerts.go +++ /dev/null @@ -1,565 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "net/url" - "os" - "sort" - "strconv" - "strings" - "text/template" - "time" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/version" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func DecisionsFromAlert(alert *models.Alert) string { - ret := "" - var decMap = make(map[string]int) - for _, decision := range alert.Decisions { - k := *decision.Type - if *decision.Simulated { - k = fmt.Sprintf("(simul)%s", k) - } - v := decMap[k] - decMap[k] = v + 1 - } - for k, v := range decMap { - if len(ret) > 0 { - ret += " " - } - ret += fmt.Sprintf("%s:%d", k, v) - } - return ret -} - -func DateFromAlert(alert *models.Alert) string { - ts, err := time.Parse(time.RFC3339, alert.CreatedAt) - if err != nil { - log.Infof("while parsing %s with %s : %s", alert.CreatedAt, time.RFC3339, err) - return alert.CreatedAt - } - return ts.Format(time.RFC822) -} - -func SourceFromAlert(alert *models.Alert) string { - - //more than one item, just number and scope - if len(alert.Decisions) > 1 { - return fmt.Sprintf("%d %ss (%s)", len(alert.Decisions), *alert.Decisions[0].Scope, *alert.Decisions[0].Origin) - } - - //fallback on single decision information - if len(alert.Decisions) == 1 { - return fmt.Sprintf("%s:%s", *alert.Decisions[0].Scope, *alert.Decisions[0].Value) - } - - //try to compose a human friendly version - if *alert.Source.Value != "" && *alert.Source.Scope != "" { - scope := fmt.Sprintf("%s:%s", *alert.Source.Scope, *alert.Source.Value) - extra := "" - if alert.Source.Cn != "" { - extra = alert.Source.Cn - } - if alert.Source.AsNumber != "" { - extra += fmt.Sprintf("/%s", alert.Source.AsNumber) - } - if alert.Source.AsName != "" { - extra += fmt.Sprintf("/%s", alert.Source.AsName) - } - - if extra != "" { - scope += " (" + extra + ")" - } - return scope - } - return "" -} - -func AlertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { - - if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(os.Stdout) - header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"} - if printMachine { - header = append(header, "machine") - } - err := csvwriter.Write(header) - if err != nil { - return err - } - for _, alertItem := range *alerts { - row := []string{ - fmt.Sprintf("%d", alertItem.ID), - *alertItem.Source.Scope, - *alertItem.Source.Value, - *alertItem.Scenario, - alertItem.Source.Cn, - alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), - *alertItem.StartAt, - } - if printMachine { - row = append(row, alertItem.MachineID) - } - err := csvwriter.Write(row) - if err != nil { - return err - } - } - csvwriter.Flush() - } else if csConfig.Cscli.Output == "json" { - if *alerts == nil { - // avoid returning "null" in json - // could be cleaner if we used slice of alerts directly - fmt.Println("[]") - return nil - } - x, _ := json.MarshalIndent(alerts, "", " ") - fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "human" { - if len(*alerts) == 0 { - fmt.Println("No active alerts") - return nil - } - alertsTable(color.Output, alerts, printMachine) - } - return nil -} - -var alertTemplate = ` -################################################################################################ - - - ID : {{.ID}} - - Date : {{.CreatedAt}} - - Machine : {{.MachineID}} - - Simulation : {{.Simulated}} - - Reason : {{.Scenario}} - - Events Count : {{.EventsCount}} - - Scope:Value : {{.Source.Scope}}{{if .Source.Value}}:{{.Source.Value}}{{end}} - - Country : {{.Source.Cn}} - - AS : {{.Source.AsName}} - - Begin : {{.StartAt}} - - End : {{.StopAt}} - - UUID : {{.UUID}} - -` - -func DisplayOneAlert(alert *models.Alert, withDetail bool) error { - if csConfig.Cscli.Output == "human" { - tmpl, err := template.New("alert").Parse(alertTemplate) - if err != nil { - return err - } - err = tmpl.Execute(os.Stdout, alert) - if err != nil { - return err - } - - alertDecisionsTable(color.Output, alert) - - if len(alert.Meta) > 0 { - fmt.Printf("\n - Context :\n") - sort.Slice(alert.Meta, func(i, j int) bool { - return alert.Meta[i].Key < alert.Meta[j].Key - }) - table := newTable(color.Output) - table.SetRowLines(false) - table.SetHeaders("Key", "Value") - for _, meta := range alert.Meta { - var valSlice []string - if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil { - return fmt.Errorf("unknown context value type '%s' : %s", meta.Value, err) - } - for _, value := range valSlice { - table.AddRow( - meta.Key, - value, - ) - } - } - table.Render() - } - - if withDetail { - fmt.Printf("\n - Events :\n") - for _, event := range alert.Events { - alertEventTable(color.Output, event) - } - } - } - return nil -} - -type cliAlerts struct{} - -func NewCLIAlerts() *cliAlerts { - return &cliAlerts{} -} - -func (cli cliAlerts) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "alerts [action]", - Short: "Manage alerts", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - Aliases: []string{"alert"}, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url %s: %w", apiURL, err) - } - Client, err = apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - VersionPrefix: "v1", - }) - - if err != nil { - return fmt.Errorf("new api client: %w", err) - } - return nil - }, - } - - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewInspectCmd()) - cmd.AddCommand(cli.NewFlushCmd()) - cmd.AddCommand(cli.NewDeleteCmd()) - - return cmd -} - -func (cli cliAlerts) NewListCmd() *cobra.Command { - var alertListFilter = apiclient.AlertsListOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - ScenarioEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - Since: new(string), - Until: new(string), - TypeEquals: new(string), - IncludeCAPI: new(bool), - OriginEquals: new(string), - } - limit := new(int) - contained := new(bool) - var printMachine bool - - cmd := &cobra.Command{ - Use: "list [filters]", - Short: "List alerts", - Example: `cscli alerts list -cscli alerts list --ip 1.2.3.4 -cscli alerts list --range 1.2.3.0/24 -cscli alerts list -s crowdsecurity/ssh-bf -cscli alerts list --type ban`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - - if err := manageCliDecisionAlerts(alertListFilter.IPEquals, alertListFilter.RangeEquals, - alertListFilter.ScopeEquals, alertListFilter.ValueEquals); err != nil { - printHelp(cmd) - return err - } - if limit != nil { - alertListFilter.Limit = limit - } - - if *alertListFilter.Until == "" { - alertListFilter.Until = nil - } else if strings.HasSuffix(*alertListFilter.Until, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*alertListFilter.Until, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Until) - } - *alertListFilter.Until = fmt.Sprintf("%d%s", days*24, "h") - } - if *alertListFilter.Since == "" { - alertListFilter.Since = nil - } else if strings.HasSuffix(*alertListFilter.Since, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*alertListFilter.Since, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Since) - } - *alertListFilter.Since = fmt.Sprintf("%d%s", days*24, "h") - } - - if *alertListFilter.IncludeCAPI { - *alertListFilter.Limit = 0 - } - - if *alertListFilter.TypeEquals == "" { - alertListFilter.TypeEquals = nil - } - if *alertListFilter.ScopeEquals == "" { - alertListFilter.ScopeEquals = nil - } - if *alertListFilter.ValueEquals == "" { - alertListFilter.ValueEquals = nil - } - if *alertListFilter.ScenarioEquals == "" { - alertListFilter.ScenarioEquals = nil - } - if *alertListFilter.IPEquals == "" { - alertListFilter.IPEquals = nil - } - if *alertListFilter.RangeEquals == "" { - alertListFilter.RangeEquals = nil - } - - if *alertListFilter.OriginEquals == "" { - alertListFilter.OriginEquals = nil - } - - if contained != nil && *contained { - alertListFilter.Contains = new(bool) - } - - alerts, _, err := Client.Alerts.List(context.Background(), alertListFilter) - if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) - } - - err = AlertsToTable(alerts, printMachine) - if err != nil { - return fmt.Errorf("unable to list alerts: %v", err) - } - - return nil - }, - } - cmd.Flags().SortFlags = false - cmd.Flags().BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") - cmd.Flags().StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") - cmd.Flags().StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") - cmd.Flags().StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - cmd.Flags().StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") - cmd.Flags().StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)") - cmd.Flags().StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)") - cmd.Flags().StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmd.Flags().StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - cmd.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") - cmd.Flags().IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") - - return cmd -} - -func (cli cliAlerts) NewDeleteCmd() *cobra.Command { - var ActiveDecision *bool - var AlertDeleteAll bool - var delAlertByID string - contained := new(bool) - var alertDeleteFilter = apiclient.AlertsDeleteOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - ScenarioEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - } - cmd := &cobra.Command{ - Use: "delete [filters] [--all]", - Short: `Delete alerts -/!\ This command can be use only on the same machine than the local API.`, - Example: `cscli alerts delete --ip 1.2.3.4 -cscli alerts delete --range 1.2.3.0/24 -cscli alerts delete -s crowdsecurity/ssh-bf"`, - DisableAutoGenTag: true, - Aliases: []string{"remove"}, - Args: cobra.ExactArgs(0), - PreRunE: func(cmd *cobra.Command, args []string) error { - if AlertDeleteAll { - return nil - } - if *alertDeleteFilter.ScopeEquals == "" && *alertDeleteFilter.ValueEquals == "" && - *alertDeleteFilter.ScenarioEquals == "" && *alertDeleteFilter.IPEquals == "" && - *alertDeleteFilter.RangeEquals == "" && delAlertByID == "" { - _ = cmd.Usage() - return fmt.Errorf("at least one filter or --all must be specified") - } - - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - - if !AlertDeleteAll { - if err := manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals, - alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil { - printHelp(cmd) - return err - } - if ActiveDecision != nil { - alertDeleteFilter.ActiveDecisionEquals = ActiveDecision - } - - if *alertDeleteFilter.ScopeEquals == "" { - alertDeleteFilter.ScopeEquals = nil - } - if *alertDeleteFilter.ValueEquals == "" { - alertDeleteFilter.ValueEquals = nil - } - if *alertDeleteFilter.ScenarioEquals == "" { - alertDeleteFilter.ScenarioEquals = nil - } - if *alertDeleteFilter.IPEquals == "" { - alertDeleteFilter.IPEquals = nil - } - if *alertDeleteFilter.RangeEquals == "" { - alertDeleteFilter.RangeEquals = nil - } - if contained != nil && *contained { - alertDeleteFilter.Contains = new(bool) - } - limit := 0 - alertDeleteFilter.Limit = &limit - } else { - limit := 0 - alertDeleteFilter = apiclient.AlertsDeleteOpts{Limit: &limit} - } - - var alerts *models.DeleteAlertsResponse - if delAlertByID == "" { - alerts, _, err = Client.Alerts.Delete(context.Background(), alertDeleteFilter) - if err != nil { - return fmt.Errorf("unable to delete alerts : %v", err) - } - } else { - alerts, _, err = Client.Alerts.DeleteOne(context.Background(), delAlertByID) - if err != nil { - return fmt.Errorf("unable to delete alert: %v", err) - } - } - log.Infof("%s alert(s) deleted", alerts.NbDeleted) - - return nil - }, - } - cmd.Flags().SortFlags = false - cmd.Flags().StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") - cmd.Flags().StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmd.Flags().StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - cmd.Flags().StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmd.Flags().StringVar(&delAlertByID, "id", "", "alert ID") - cmd.Flags().BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts") - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - return cmd -} - -func (cli cliAlerts) NewInspectCmd() *cobra.Command { - var details bool - cmd := &cobra.Command{ - Use: `inspect "alert_id"`, - Short: `Show info about an alert`, - Example: `cscli alerts inspect 123`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 0 { - printHelp(cmd) - return fmt.Errorf("missing alert_id") - } - for _, alertID := range args { - id, err := strconv.Atoi(alertID) - if err != nil { - return fmt.Errorf("bad alert id %s", alertID) - } - alert, _, err := Client.Alerts.GetByID(context.Background(), id) - if err != nil { - return fmt.Errorf("can't find alert with id %s: %s", alertID, err) - } - switch csConfig.Cscli.Output { - case "human": - if err := DisplayOneAlert(alert, details); err != nil { - continue - } - case "json": - data, err := json.MarshalIndent(alert, "", " ") - if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) - } - fmt.Printf("%s\n", string(data)) - case "raw": - data, err := yaml.Marshal(alert) - if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err) - } - fmt.Printf("%s\n", string(data)) - } - } - - return nil - }, - } - cmd.Flags().SortFlags = false - cmd.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events") - - return cmd -} - -func (cli cliAlerts) NewFlushCmd() *cobra.Command { - var maxItems int - var maxAge string - cmd := &cobra.Command{ - Use: `flush`, - Short: `Flush alerts -/!\ This command can be used only on the same machine than the local API`, - Example: `cscli alerts flush --max-items 1000 --max-age 7d`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - var err error - if err := require.LAPI(csConfig); err != nil { - return err - } - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - log.Info("Flushing alerts. !! This may take a long time !!") - err = dbClient.FlushAlerts(maxAge, maxItems) - if err != nil { - return fmt.Errorf("unable to flush alerts: %s", err) - } - log.Info("Alerts flushed") - - return nil - }, - } - - cmd.Flags().SortFlags = false - cmd.Flags().IntVar(&maxItems, "max-items", 5000, "Maximum number of alert items to keep in the database") - cmd.Flags().StringVar(&maxAge, "max-age", "7d", "Maximum age of alert items to keep in the database") - - return cmd -} diff --git a/cmd/crowdsec-cli/ask/ask.go b/cmd/crowdsec-cli/ask/ask.go new file mode 100644 index 00000000000..484ccb30c8a --- /dev/null +++ b/cmd/crowdsec-cli/ask/ask.go @@ -0,0 +1,20 @@ +package ask + +import ( + "github.com/AlecAivazis/survey/v2" +) + +func YesNo(message string, defaultAnswer bool) (bool, error) { + var answer bool + + prompt := &survey.Confirm{ + Message: message, + Default: defaultAnswer, + } + + if err := survey.AskOne(prompt, &answer); err != nil { + return defaultAnswer, err + } + + return answer, nil +} diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go deleted file mode 100644 index 410827b3159..00000000000 --- a/cmd/crowdsec-cli/bouncers.go +++ /dev/null @@ -1,317 +0,0 @@ -package main - -import ( - "encoding/csv" - "encoding/json" - "fmt" - "os" - "slices" - "strings" - "time" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func askYesNo(message string, defaultAnswer bool) (bool, error) { - var answer bool - - prompt := &survey.Confirm{ - Message: message, - Default: defaultAnswer, - } - - if err := survey.AskOne(prompt, &answer); err != nil { - return defaultAnswer, err - } - - return answer, nil -} - -type cliBouncers struct { - db *database.Client - cfg func() *csconfig.Config -} - -func NewCLIBouncers(getconfig func() *csconfig.Config) *cliBouncers { - return &cliBouncers{ - cfg: getconfig, - } -} - -func (cli *cliBouncers) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "bouncers [action]", - Short: "Manage bouncers [requires local API]", - Long: `To list/add/delete/prune bouncers. -Note: This command requires database direct access, so is intended to be run on Local API/master. -`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"bouncer"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - var err error - if err = require.LAPI(cli.cfg()); err != nil { - return err - } - - cli.db, err = database.NewClient(cli.cfg().DbConfig) - if err != nil { - return fmt.Errorf("can't connect to the database: %s", err) - } - - return nil - }, - } - - cmd.AddCommand(cli.newListCmd()) - cmd.AddCommand(cli.newAddCmd()) - cmd.AddCommand(cli.newDeleteCmd()) - cmd.AddCommand(cli.newPruneCmd()) - - return cmd -} - -func (cli *cliBouncers) list() error { - out := color.Output - - bouncers, err := cli.db.ListBouncers() - if err != nil { - return fmt.Errorf("unable to list bouncers: %s", err) - } - - switch cli.cfg().Cscli.Output { - case "human": - getBouncersTable(out, bouncers) - case "json": - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(bouncers); err != nil { - return fmt.Errorf("failed to marshal: %w", err) - } - - return nil - case "raw": - csvwriter := csv.NewWriter(out) - - if err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}); err != nil { - return fmt.Errorf("failed to write raw header: %w", err) - } - - for _, b := range bouncers { - valid := "validated" - if b.Revoked { - valid = "pending" - } - - if err := csvwriter.Write([]string{b.Name, b.IPAddress, valid, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType}); err != nil { - return fmt.Errorf("failed to write raw: %w", err) - } - } - - csvwriter.Flush() - } - - return nil -} - -func (cli *cliBouncers) newListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list all bouncers within the database", - Example: `cscli bouncers list`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.list() - }, - } - - return cmd -} - -func (cli *cliBouncers) add(bouncerName string, key string) error { - var err error - - keyLength := 32 - - if key == "" { - key, err = middlewares.GenerateAPIKey(keyLength) - if err != nil { - return fmt.Errorf("unable to generate api key: %s", err) - } - } - - _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) - if err != nil { - return fmt.Errorf("unable to create bouncer: %s", err) - } - - switch cli.cfg().Cscli.Output { - case "human": - fmt.Printf("API key for '%s':\n\n", bouncerName) - fmt.Printf(" %s\n\n", key) - fmt.Print("Please keep this key since you will not be able to retrieve it!\n") - case "raw": - fmt.Print(key) - case "json": - j, err := json.Marshal(key) - if err != nil { - return fmt.Errorf("unable to marshal api key") - } - - fmt.Print(string(j)) - } - - return nil -} - -func (cli *cliBouncers) newAddCmd() *cobra.Command { - var key string - - cmd := &cobra.Command{ - Use: "add MyBouncerName", - Short: "add a single bouncer to the database", - Example: `cscli bouncers add MyBouncerName -cscli bouncers add MyBouncerName --key `, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args[0], key) - }, - } - - flags := cmd.Flags() - flags.StringP("length", "l", "", "length of the api key") - flags.MarkDeprecated("length", "use --key instead") - flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") - - return cmd -} - -func (cli *cliBouncers) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - bouncers, err := cli.db.ListBouncers() - if err != nil { - cobra.CompError("unable to list bouncers " + err.Error()) - } - - ret :=[]string{} - - for _, bouncer := range bouncers { - if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { - ret = append(ret, bouncer.Name) - } - } - - return ret, cobra.ShellCompDirectiveNoFileComp -} - -func (cli *cliBouncers) delete(bouncers []string) error { - for _, bouncerID := range bouncers { - err := cli.db.DeleteBouncer(bouncerID) - if err != nil { - return fmt.Errorf("unable to delete bouncer '%s': %s", bouncerID, err) - } - - log.Infof("bouncer '%s' deleted successfully", bouncerID) - } - - return nil -} - -func (cli *cliBouncers) newDeleteCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "delete MyBouncerName", - Short: "delete bouncer(s) from the database", - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - ValidArgsFunction: cli.deleteValid, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args) - }, - } - - return cmd -} - -func (cli *cliBouncers) prune(duration time.Duration, force bool) error { - if duration < 2*time.Minute { - if yes, err := askYesNo( - "The duration you provided is less than 2 minutes. " + - "This may remove active bouncers. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - bouncers, err := cli.db.QueryBouncersLastPulltimeLT(time.Now().UTC().Add(duration)) - if err != nil { - return fmt.Errorf("unable to query bouncers: %w", err) - } - - if len(bouncers) == 0 { - fmt.Println("No bouncers to prune.") - return nil - } - - getBouncersTable(color.Output, bouncers) - - if !force { - if yes, err := askYesNo( - "You are about to PERMANENTLY remove the above bouncers from the database. " + - "These will NOT be recoverable. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - deleted, err := cli.db.BulkDeleteBouncers(bouncers) - if err != nil { - return fmt.Errorf("unable to prune bouncers: %s", err) - } - - fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) - - return nil -} - -func (cli *cliBouncers) newPruneCmd() *cobra.Command { - var ( - duration time.Duration - force bool - ) - - const defaultDuration = 60 * time.Minute - - cmd := &cobra.Command{ - Use: "prune", - Short: "prune multiple bouncers from the database", - Args: cobra.NoArgs, - DisableAutoGenTag: true, - Example: `cscli bouncers prune -d 45m -cscli bouncers prune -d 45m --force`, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, force) - }, - } - - flags := cmd.Flags() - flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since last pull") - flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") - - return cmd -} diff --git a/cmd/crowdsec-cli/bouncers_table.go b/cmd/crowdsec-cli/bouncers_table.go deleted file mode 100644 index 0ea725f5598..00000000000 --- a/cmd/crowdsec-cli/bouncers_table.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -import ( - "io" - "time" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/database/ent" -) - -func getBouncersTable(out io.Writer, bouncers []*ent.Bouncer) { - t := newLightTable(out) - t.SetHeaders("Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, b := range bouncers { - var revoked string - if !b.Revoked { - revoked = emoji.CheckMark.String() - } else { - revoked = emoji.Prohibited.String() - } - - t.AddRow(b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go deleted file mode 100644 index 358d91ee215..00000000000 --- a/cmd/crowdsec-cli/capi.go +++ /dev/null @@ -1,196 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net/url" - "os" - - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/go-cs-lib/version" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -const ( - CAPIBaseURL = "https://api.crowdsec.net/" - CAPIURLPrefix = "v3" -) - -type cliCapi struct{} - -func NewCLICapi() *cliCapi { - return &cliCapi{} -} - -func (cli cliCapi) NewCommand() *cobra.Command { - var cmd = &cobra.Command{ - Use: "capi [action]", - Short: "Manage interaction with Central API (CAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - if err := require.LAPI(csConfig); err != nil { - return err - } - - if err := require.CAPI(csConfig); err != nil { - return err - } - - return nil - }, - } - - cmd.AddCommand(cli.NewRegisterCmd()) - cmd.AddCommand(cli.NewStatusCmd()) - - return cmd -} - -func (cli cliCapi) NewRegisterCmd() *cobra.Command { - var ( - capiUserPrefix string - outputFile string - ) - - var cmd = &cobra.Command{ - Use: "register", - Short: "Register to Central API (CAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - var err error - capiUser, err := generateID(capiUserPrefix) - if err != nil { - return fmt.Errorf("unable to generate machine id: %s", err) - } - password := strfmt.Password(generatePassword(passwordLength)) - apiurl, err := url.Parse(types.CAPIBaseURL) - if err != nil { - return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) - } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: capiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: CAPIURLPrefix, - }, nil) - - if err != nil { - return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) - } - log.Printf("Successfully registered to Central API (CAPI)") - - var dumpFile string - - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - dumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } else { - dumpFile = "" - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: capiUser, - Password: password.String(), - URL: types.CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) - } - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) - } - log.Printf("Central API credentials written to '%s'", dumpFile) - } else { - fmt.Println(string(apiConfigDump)) - } - - log.Warning(ReloadMessage()) - - return nil - }, - } - - cmd.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination") - cmd.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)") - - if err := cmd.Flags().MarkHidden("schmilblick"); err != nil { - log.Fatalf("failed to hide flag: %s", err) - } - - return cmd -} - -func (cli cliCapi) NewStatusCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "Check status with the Central API (CAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - if err := require.CAPIRegistered(csConfig); err != nil { - return err - } - - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - - apiurl, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", csConfig.API.Server.OnlineClient.Credentials.URL, err) - } - - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - return err - } - - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get scenarios: %w", err) - } - - if len(scenarios) == 0 { - return fmt.Errorf("no scenarios installed, abort") - } - - Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) - if err != nil { - return fmt.Errorf("init default client: %w", err) - } - - t := models.WatcherAuthRequest{ - MachineID: &csConfig.API.Server.OnlineClient.Credentials.Login, - Password: &password, - Scenarios: scenarios, - } - - log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl) - - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) - } - log.Infof("You can successfully interact with Central API (CAPI)") - - return nil - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/clialert/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go new file mode 100644 index 00000000000..75454e945f2 --- /dev/null +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -0,0 +1,603 @@ +package clialert + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "sort" + "strconv" + "strings" + "text/template" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func decisionsFromAlert(alert *models.Alert) string { + ret := "" + decMap := make(map[string]int) + + for _, decision := range alert.Decisions { + k := *decision.Type + if *decision.Simulated { + k = fmt.Sprintf("(simul)%s", k) + } + + v := decMap[k] + decMap[k] = v + 1 + } + + for _, key := range maptools.SortedKeys(decMap) { + if ret != "" { + ret += " " + } + + ret += fmt.Sprintf("%s:%d", key, decMap[key]) + } + + return ret +} + +func (cli *cliAlerts) alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { + cfg := cli.cfg() + switch cfg.Cscli.Output { + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"} + + if printMachine { + header = append(header, "machine") + } + + if err := csvwriter.Write(header); err != nil { + return err + } + + for _, alertItem := range *alerts { + row := []string{ + strconv.FormatInt(alertItem.ID, 10), + *alertItem.Source.Scope, + *alertItem.Source.Value, + *alertItem.Scenario, + alertItem.Source.Cn, + alertItem.Source.GetAsNumberName(), + decisionsFromAlert(alertItem), + *alertItem.StartAt, + } + if printMachine { + row = append(row, alertItem.MachineID) + } + + if err := csvwriter.Write(row); err != nil { + return err + } + } + + csvwriter.Flush() + case "json": + if *alerts == nil { + // avoid returning "null" in json + // could be cleaner if we used slice of alerts directly + fmt.Println("[]") + return nil + } + + x, _ := json.MarshalIndent(alerts, "", " ") + fmt.Print(string(x)) + case "human": + if len(*alerts) == 0 { + fmt.Println("No active alerts") + return nil + } + + alertsTable(color.Output, cfg.Cscli.Color, alerts, printMachine) + } + + return nil +} + +func (cli *cliAlerts) displayOneAlert(alert *models.Alert, withDetail bool) error { + alertTemplate := ` +################################################################################################ + + - ID : {{.ID}} + - Date : {{.CreatedAt}} + - Machine : {{.MachineID}} + - Simulation : {{.Simulated}} + - Remediation : {{.Remediation}} + - Reason : {{.Scenario}} + - Events Count : {{.EventsCount}} + - Scope:Value : {{.Source.Scope}}{{if .Source.Value}}:{{.Source.Value}}{{end}} + - Country : {{.Source.Cn}} + - AS : {{.Source.AsName}} + - Begin : {{.StartAt}} + - End : {{.StopAt}} + - UUID : {{.UUID}} + +` + + tmpl, err := template.New("alert").Parse(alertTemplate) + if err != nil { + return err + } + + if err = tmpl.Execute(os.Stdout, alert); err != nil { + return err + } + + cfg := cli.cfg() + + alertDecisionsTable(color.Output, cfg.Cscli.Color, alert) + + if len(alert.Meta) > 0 { + fmt.Printf("\n - Context :\n") + sort.Slice(alert.Meta, func(i, j int) bool { + return alert.Meta[i].Key < alert.Meta[j].Key + }) + + table := cstable.New(color.Output, cfg.Cscli.Color) + table.SetRowLines(false) + table.SetHeaders("Key", "Value") + + for _, meta := range alert.Meta { + var valSlice []string + if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil { + return fmt.Errorf("unknown context value type '%s': %w", meta.Value, err) + } + + for _, value := range valSlice { + table.AddRow( + meta.Key, + value, + ) + } + } + + table.Render() + } + + if withDetail { + fmt.Printf("\n - Events :\n") + + for _, event := range alert.Events { + alertEventTable(color.Output, cfg.Cscli.Color, event) + } + } + + return nil +} + +type configGetter func() *csconfig.Config + +type cliAlerts struct { + client *apiclient.ApiClient + cfg configGetter +} + +func New(getconfig configGetter) *cliAlerts { + return &cliAlerts{ + cfg: getconfig, + } +} + +func (cli *cliAlerts) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "alerts [action]", + Short: "Manage alerts", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + Aliases: []string{"alert"}, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + cli.client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return fmt.Errorf("creating api client: %w", err) + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newFlushCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + + return cmd +} + +func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { + var err error + + *alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals) + if err != nil { + return err + } + + if limit != nil { + alertListFilter.Limit = limit + } + + if *alertListFilter.Until == "" { + alertListFilter.Until = nil + } else if strings.HasSuffix(*alertListFilter.Until, "d") { + /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ + realDuration := strings.TrimSuffix(*alertListFilter.Until, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Until) + } + + *alertListFilter.Until = fmt.Sprintf("%d%s", days*24, "h") + } + + if *alertListFilter.Since == "" { + alertListFilter.Since = nil + } else if strings.HasSuffix(*alertListFilter.Since, "d") { + // time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier + realDuration := strings.TrimSuffix(*alertListFilter.Since, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *alertListFilter.Since) + } + + *alertListFilter.Since = fmt.Sprintf("%d%s", days*24, "h") + } + + if *alertListFilter.IncludeCAPI { + *alertListFilter.Limit = 0 + } + + if *alertListFilter.TypeEquals == "" { + alertListFilter.TypeEquals = nil + } + + if *alertListFilter.ScopeEquals == "" { + alertListFilter.ScopeEquals = nil + } + + if *alertListFilter.ValueEquals == "" { + alertListFilter.ValueEquals = nil + } + + if *alertListFilter.ScenarioEquals == "" { + alertListFilter.ScenarioEquals = nil + } + + if *alertListFilter.IPEquals == "" { + alertListFilter.IPEquals = nil + } + + if *alertListFilter.RangeEquals == "" { + alertListFilter.RangeEquals = nil + } + + if *alertListFilter.OriginEquals == "" { + alertListFilter.OriginEquals = nil + } + + if contained != nil && *contained { + alertListFilter.Contains = new(bool) + } + + alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter) + if err != nil { + return fmt.Errorf("unable to list alerts: %w", err) + } + + if err = cli.alertsToTable(alerts, printMachine); err != nil { + return fmt.Errorf("unable to list alerts: %w", err) + } + + return nil +} + +func (cli *cliAlerts) newListCmd() *cobra.Command { + alertListFilter := apiclient.AlertsListOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + ScenarioEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + Since: new(string), + Until: new(string), + TypeEquals: new(string), + IncludeCAPI: new(bool), + OriginEquals: new(string), + } + + limit := new(int) + contained := new(bool) + + var printMachine bool + + cmd := &cobra.Command{ + Use: "list [filters]", + Short: "List alerts", + Example: `cscli alerts list +cscli alerts list --ip 1.2.3.4 +cscli alerts list --range 1.2.3.0/24 +cscli alerts list --origin lists +cscli alerts list -s crowdsecurity/ssh-bf +cscli alerts list --type ban`, + Long: `List alerts with optional filters`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") + flags.StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") + flags.StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") + flags.StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") + flags.StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value )") + flags.StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)") + flags.StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)") + flags.StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts") + flags.IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)") + + return cmd +} + +func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { + var err error + + if !deleteAll { + *delFilter.ScopeEquals, err = SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { + return err + } + + if activeDecision != nil { + delFilter.ActiveDecisionEquals = activeDecision + } + + if *delFilter.ScopeEquals == "" { + delFilter.ScopeEquals = nil + } + + if *delFilter.ValueEquals == "" { + delFilter.ValueEquals = nil + } + + if *delFilter.ScenarioEquals == "" { + delFilter.ScenarioEquals = nil + } + + if *delFilter.IPEquals == "" { + delFilter.IPEquals = nil + } + + if *delFilter.RangeEquals == "" { + delFilter.RangeEquals = nil + } + + if contained != nil && *contained { + delFilter.Contains = new(bool) + } + + limit := 0 + delFilter.Limit = &limit + } else { + limit := 0 + delFilter = apiclient.AlertsDeleteOpts{Limit: &limit} + } + + var alerts *models.DeleteAlertsResponse + if delAlertByID == "" { + alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter) + if err != nil { + return fmt.Errorf("unable to delete alerts: %w", err) + } + } else { + alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID) + if err != nil { + return fmt.Errorf("unable to delete alert: %w", err) + } + } + + log.Infof("%s alert(s) deleted", alerts.NbDeleted) + + return nil +} + +func (cli *cliAlerts) newDeleteCmd() *cobra.Command { + var ( + activeDecision *bool + deleteAll bool + delAlertByID string + ) + + delFilter := apiclient.AlertsDeleteOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + ScenarioEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + } + + contained := new(bool) + + cmd := &cobra.Command{ + Use: "delete [filters] [--all]", + Short: `Delete alerts +/!\ This command can be use only on the same machine than the local API.`, + Example: `cscli alerts delete --ip 1.2.3.4 +cscli alerts delete --range 1.2.3.0/24 +cscli alerts delete -s crowdsecurity/ssh-bf"`, + DisableAutoGenTag: true, + Aliases: []string{"remove"}, + Args: cobra.ExactArgs(0), + PreRunE: func(cmd *cobra.Command, _ []string) error { + if deleteAll { + return nil + } + if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && + *delFilter.ScenarioEquals == "" && *delFilter.IPEquals == "" && + *delFilter.RangeEquals == "" && delAlertByID == "" { + _ = cmd.Usage() + return errors.New("at least one filter or --all must be specified") + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVar(delFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") + flags.StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVar(&delAlertByID, "id", "", "alert ID") + flags.BoolVarP(&deleteAll, "all", "a", false, "delete all alerts") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + + return cmd +} + +func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error { + cfg := cli.cfg() + + for _, alertID := range alertIDs { + id, err := strconv.Atoi(alertID) + if err != nil { + return fmt.Errorf("bad alert id %s", alertID) + } + + alert, _, err := cli.client.Alerts.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("can't find alert with id %s: %w", alertID, err) + } + + switch cfg.Cscli.Output { + case "human": + if err := cli.displayOneAlert(alert, details); err != nil { + log.Warnf("unable to display alert with id %s: %s", alertID, err) + continue + } + case "json": + data, err := json.MarshalIndent(alert, "", " ") + if err != nil { + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) + } + + fmt.Printf("%s\n", string(data)) + case "raw": + data, err := yaml.Marshal(alert) + if err != nil { + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) + } + + fmt.Println(string(data)) + } + } + + return nil +} + +func (cli *cliAlerts) newInspectCmd() *cobra.Command { + var details bool + + cmd := &cobra.Command{ + Use: `inspect "alert_id"`, + Short: `Show info about an alert`, + Example: `cscli alerts inspect 123`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + _ = cmd.Help() + return errors.New("missing alert_id") + } + return cli.inspect(cmd.Context(), details, args...) + }, + } + + cmd.Flags().SortFlags = false + cmd.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events") + + return cmd +} + +func (cli *cliAlerts) newFlushCmd() *cobra.Command { + var ( + maxItems int + maxAge string + ) + + cmd := &cobra.Command{ + Use: `flush`, + Short: `Flush alerts +/!\ This command can be used only on the same machine than the local API`, + Example: `cscli alerts flush --max-items 1000 --max-age 7d`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + if err := require.LAPI(cfg); err != nil { + return err + } + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + log.Info("Flushing alerts. !! This may take a long time !!") + err = db.FlushAlerts(ctx, maxAge, maxItems) + if err != nil { + return fmt.Errorf("unable to flush alerts: %w", err) + } + log.Info("Alerts flushed") + + return nil + }, + } + + cmd.Flags().SortFlags = false + cmd.Flags().IntVar(&maxItems, "max-items", 5000, "Maximum number of alert items to keep in the database") + cmd.Flags().StringVar(&maxAge, "max-age", "7d", "Maximum age of alert items to keep in the database") + + return cmd +} diff --git a/cmd/crowdsec-cli/clialert/sanitize.go b/cmd/crowdsec-cli/clialert/sanitize.go new file mode 100644 index 00000000000..87b110649da --- /dev/null +++ b/cmd/crowdsec-cli/clialert/sanitize.go @@ -0,0 +1,26 @@ +package clialert + +import ( + "fmt" + "net" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// SanitizeScope validates ip and range and sets the scope accordingly to our case convention. +func SanitizeScope(scope, ip, ipRange string) (string, error) { + if ipRange != "" { + _, _, err := net.ParseCIDR(ipRange) + if err != nil { + return "", fmt.Errorf("%s is not a valid range", ipRange) + } + } + + if ip != "" { + if net.ParseIP(ip) == nil { + return "", fmt.Errorf("%s is not a valid ip", ip) + } + } + + return types.NormalizeScope(scope), nil +} diff --git a/cmd/crowdsec-cli/alerts_table.go b/cmd/crowdsec-cli/clialert/table.go similarity index 80% rename from cmd/crowdsec-cli/alerts_table.go rename to cmd/crowdsec-cli/clialert/table.go index ec457f3723e..1416e1e435c 100644 --- a/cmd/crowdsec-cli/alerts_table.go +++ b/cmd/crowdsec-cli/clialert/table.go @@ -1,4 +1,4 @@ -package main +package clialert import ( "fmt" @@ -9,16 +9,19 @@ import ( log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { - t := newTable(out) +func alertsTable(out io.Writer, wantColor string, alerts *models.GetAlertsResponse, printMachine bool) { + t := cstable.New(out, wantColor) t.SetRowLines(false) + header := []string{"ID", "value", "reason", "country", "as", "decisions", "created_at"} if printMachine { header = append(header, "machine") } + t.SetHeaders(header...) for _, alertItem := range *alerts { @@ -35,7 +38,7 @@ func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine b *alertItem.Scenario, alertItem.Source.Cn, alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), + decisionsFromAlert(alertItem), *alertItem.StartAt, } @@ -49,25 +52,30 @@ func alertsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine b t.Render() } -func alertDecisionsTable(out io.Writer, alert *models.Alert) { +func alertDecisionsTable(out io.Writer, wantColor string, alert *models.Alert) { foundActive := false - t := newTable(out) + t := cstable.New(out, wantColor) t.SetRowLines(false) t.SetHeaders("ID", "scope:value", "action", "expiration", "created_at") + for _, decision := range alert.Decisions { parsedDuration, err := time.ParseDuration(*decision.Duration) if err != nil { log.Error(err) } + expire := time.Now().UTC().Add(parsedDuration) if time.Now().UTC().After(expire) { continue } + foundActive = true scopeAndValue := *decision.Scope + if *decision.Value != "" { scopeAndValue += ":" + *decision.Value } + t.AddRow( strconv.Itoa(int(decision.ID)), scopeAndValue, @@ -76,16 +84,17 @@ func alertDecisionsTable(out io.Writer, alert *models.Alert) { alert.CreatedAt, ) } + if foundActive { fmt.Printf(" - Active Decisions :\n") t.Render() // Send output } } -func alertEventTable(out io.Writer, event *models.Event) { +func alertEventTable(out io.Writer, wantColor string, event *models.Event) { fmt.Fprintf(out, "\n- Date: %s\n", *event.Timestamp) - t := newTable(out) + t := cstable.New(out, wantColor) t.SetHeaders("Key", "Value") sort.Slice(event.Meta, func(i, j int) bool { return event.Meta[i].Key < event.Meta[j].Key diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go new file mode 100644 index 00000000000..226fbb7e922 --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -0,0 +1,497 @@ +package clibouncer + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "slices" + "strings" + "time" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter = func() *csconfig.Config + +type cliBouncers struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliBouncers { + return &cliBouncers{ + cfg: cfg, + } +} + +func (cli *cliBouncers) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "bouncers [action]", + Short: "Manage bouncers [requires local API]", + Long: `To list/add/delete/prune bouncers. +Note: This command requires database direct access, so is intended to be run on Local API/master. +`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"bouncer"}, + DisableAutoGenTag: true, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +func (cli *cliBouncers) listHuman(out io.Writer, bouncers ent.Bouncers) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"}) + + for _, b := range bouncers { + revoked := emoji.CheckMark + if b.Revoked { + revoked = emoji.Prohibited + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + t.AppendRow(table.Row{b.Name, b.IPAddress, revoked, lastPull, b.Type, b.Version, b.AuthType}) + } + + io.WriteString(out, t.Render()+"\n") +} + +// bouncerInfo contains only the data we want for inspect/list +type bouncerInfo struct { + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Name string `json:"name"` + Revoked bool `json:"revoked"` + IPAddress string `json:"ip_address"` + Type string `json:"type"` + Version string `json:"version"` + LastPull *time.Time `json:"last_pull"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` +} + +func newBouncerInfo(b *ent.Bouncer) bouncerInfo { + return bouncerInfo{ + CreatedAt: b.CreatedAt, + UpdatedAt: b.UpdatedAt, + Name: b.Name, + Revoked: b.Revoked, + IPAddress: b.IPAddress, + Type: b.Type, + Version: b.Version, + LastPull: b.LastPull, + AuthType: b.AuthType, + OS: clientinfo.GetOSNameAndVersion(b), + Featureflags: clientinfo.GetFeatureFlagList(b), + } +} + +func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { + csvwriter := csv.NewWriter(out) + + if err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}); err != nil { + return fmt.Errorf("failed to write raw header: %w", err) + } + + for _, b := range bouncers { + valid := "validated" + if b.Revoked { + valid = "pending" + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{b.Name, b.IPAddress, valid, lastPull, b.Type, b.Version, b.AuthType}); err != nil { + return fmt.Errorf("failed to write raw: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + bouncers, err := db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, bouncers) + case "json": + info := make([]bouncerInfo, 0, len(bouncers)) + for _, b := range bouncers { + info = append(info, newBouncerInfo(b)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, bouncers) + } + + return nil +} + +func (cli *cliBouncers) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all bouncers within the database", + Example: `cscli bouncers list`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} + +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { + var err error + + keyLength := 32 + + if key == "" { + key, err = middlewares.GenerateAPIKey(keyLength) + if err != nil { + return fmt.Errorf("unable to generate api key: %w", err) + } + } + + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + if err != nil { + return fmt.Errorf("unable to create bouncer: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + fmt.Printf("API key for '%s':\n\n", bouncerName) + fmt.Printf(" %s\n\n", key) + fmt.Print("Please keep this key since you will not be able to retrieve it!\n") + case "raw": + fmt.Print(key) + case "json": + j, err := json.Marshal(key) + if err != nil { + return errors.New("unable to serialize api key") + } + + fmt.Print(string(j)) + } + + return nil +} + +func (cli *cliBouncers) newAddCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ + Use: "add MyBouncerName", + Short: "add a single bouncer to the database", + Example: `cscli bouncers add MyBouncerName +cscli bouncers add MyBouncerName --key `, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) + }, + } + + flags := cmd.Flags() + flags.StringP("length", "l", "", "length of the api key") + _ = flags.MarkDeprecated("length", "use --key instead") + flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") + + return cmd +} + +// validBouncerID returns a list of bouncer IDs for command completion +func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + bouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, bouncer := range bouncers { + if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { + ret = append(ret, bouncer.Name) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { + for _, bouncerID := range bouncers { + if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { + var notFoundErr *database.BouncerNotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + return nil + } + + return fmt.Errorf("unable to delete bouncer: %w", err) + } + + log.Infof("bouncer '%s' deleted successfully", bouncerID) + } + + return nil +} + +func (cli *cliBouncers) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete MyBouncerName", + Short: "delete bouncer(s) from the database", + Example: `cscli bouncers delete "bouncer1" "bouncer2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more bouncers don't exist") + + return cmd +} + +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { + if duration < 2*time.Minute { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This may remove active bouncers. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration)) + if err != nil { + return fmt.Errorf("unable to query bouncers: %w", err) + } + + if len(bouncers) == 0 { + fmt.Println("No bouncers to prune.") + return nil + } + + cli.listHuman(color.Output, bouncers) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above bouncers from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) + if err != nil { + return fmt.Errorf("unable to prune bouncers: %w", err) + } + + fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) + + return nil +} + +func (cli *cliBouncers) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + force bool + ) + + const defaultDuration = 60 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple bouncers from the database", + Args: cobra.NoArgs, + DisableAutoGenTag: true, + Example: `cscli bouncers prune -d 45m +cscli bouncers prune -d 45m --force`, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since last pull") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} + +func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Bouncer: " + bouncer.Name) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + lastPull := "" + if bouncer.LastPull != nil { + lastPull = bouncer.LastPull.String() + } + + t.AppendRows([]table.Row{ + {"Created At", bouncer.CreatedAt}, + {"Last Update", bouncer.UpdatedAt}, + {"Revoked?", bouncer.Revoked}, + {"IP Address", bouncer.IPAddress}, + {"Type", bouncer.Type}, + {"Version", bouncer.Version}, + {"Last Pull", lastPull}, + {"Auth type", bouncer.AuthType}, + {"OS", clientinfo.GetOSNameAndVersion(bouncer)}, + }) + + for _, ff := range clientinfo.GetFeatureFlagList(bouncer) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliBouncers) inspect(bouncer *ent.Bouncer) error { + out := color.Output + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, bouncer) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newBouncerInfo(bouncer)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliBouncers) newInspectCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inspect [bouncer_name]", + Short: "inspect a bouncer by name", + Example: `cscli bouncers inspect "bouncer1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + bouncerName := args[0] + + b, err := cli.db.Ent.Bouncer.Query(). + Where(bouncer.Name(bouncerName)). + Only(cmd.Context()) + if err != nil { + return fmt.Errorf("unable to read bouncer data '%s': %w", bouncerName, err) + } + + return cli.inspect(b) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clicapi/capi.go b/cmd/crowdsec-cli/clicapi/capi.go new file mode 100644 index 00000000000..cba66f11104 --- /dev/null +++ b/cmd/crowdsec-cli/clicapi/capi.go @@ -0,0 +1,248 @@ +package clicapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/url" + "os" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter = func() *csconfig.Config + +type cliCapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliCapi { + return &cliCapi{ + cfg: cfg, + } +} + +func (cli *cliCapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "capi [action]", + Short: "Manage interaction with Central API (CAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + + return require.CAPI(cfg) + }, + } + + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliCapi) register(ctx context.Context, capiUserPrefix string, outputFile string) error { + cfg := cli.cfg() + + capiUser, err := idgen.GenerateMachineID(capiUserPrefix) + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) + + apiurl, err := url.Parse(types.CAPIBaseURL) + if err != nil { + return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) + } + + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ + MachineID: capiUser, + Password: password, + URL: apiurl, + VersionPrefix: "v3", + }, nil) + if err != nil { + return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) + } + + log.Infof("Successfully registered to Central API (CAPI)") + + var dumpFile string + + switch { + case outputFile != "": + dumpFile = outputFile + case cfg.API.Server.OnlineClient.CredentialsFilePath != "": + dumpFile = cfg.API.Server.OnlineClient.CredentialsFilePath + default: + dumpFile = "" + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: capiUser, + Password: password.String(), + URL: types.CAPIBaseURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + log.Infof("Central API credentials written to '%s'", dumpFile) + } else { + fmt.Println(string(apiConfigDump)) + } + + log.Warning(reload.Message) + + return nil +} + +func (cli *cliCapi) newRegisterCmd() *cobra.Command { + var ( + capiUserPrefix string + outputFile string + ) + + cmd := &cobra.Command{ + Use: "register", + Short: "Register to Central API (CAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), capiUserPrefix, outputFile) + }, + } + + cmd.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination") + cmd.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)") + + _ = cmd.Flags().MarkHidden("schmilblick") + + return cmd +} + +// queryCAPIStatus checks if the Central API is reachable, and if the credentials are correct. It then checks if the instance is enrolle in the console. +func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) { + apiURL, err := url.Parse(credURL) + if err != nil { + return false, false, err + } + + itemsForAPI := hub.GetInstalledListForAPI() + + if len(itemsForAPI) == 0 { + return false, false, errors.New("no scenarios or appsec-rules installed, abort") + } + + passwd := strfmt.Password(password) + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: login, + Password: passwd, + Scenarios: itemsForAPI, + URL: apiURL, + // I don't believe papi is neede to check enrollement + // PapiURL: papiURL, + VersionPrefix: "v3", + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil + }, + }) + if err != nil { + return false, false, err + } + + pw := strfmt.Password(password) + + t := models.WatcherAuthRequest{ + MachineID: &login, + Password: &pw, + Scenarios: itemsForAPI, + } + + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, t) + if err != nil { + return false, false, err + } + + client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + + if client.IsEnrolled() { + return true, true, nil + } + + return true, false, nil +} + +func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { + cfg := cli.cfg() + + if err := require.CAPIRegistered(cfg); err != nil { + return err + } + + cred := cfg.API.Server.OnlineClient.Credentials + + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) + + auth, enrolled, err := queryCAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) + if err != nil { + return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) + } + + if auth { + fmt.Fprint(out, "You can successfully interact with Central API (CAPI)\n") + } + + if enrolled { + fmt.Fprint(out, "Your instance is enrolled in the console\n") + } + + return nil +} + +func (cli *cliCapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Check status with the Central API (CAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/cliconsole/console.go b/cmd/crowdsec-cli/cliconsole/console.go new file mode 100644 index 00000000000..448ddcee7fa --- /dev/null +++ b/cmd/crowdsec-cli/cliconsole/console.go @@ -0,0 +1,417 @@ +package cliconsole + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "strings" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type configGetter func() *csconfig.Config + +type cliConsole struct { + cfg configGetter +} + +func New(cfg configGetter) *cliConsole { + return &cliConsole{ + cfg: cfg, + } +} + +func (cli *cliConsole) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "console [action]", + Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := require.CAPI(cfg); err != nil { + return err + } + + return require.CAPIRegistered(cfg) + }, + } + + cmd.AddCommand(cli.newEnrollCmd()) + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error { + cfg := cli.cfg() + password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) + + apiURL, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL) + if err != nil { + return fmt.Errorf("could not parse CAPI URL: %w", err) + } + + enableOpts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} + + if len(opts) != 0 { + for _, opt := range opts { + valid := false + + if opt == "all" { + enableOpts = csconfig.CONSOLE_CONFIGS + break + } + + for _, availableOpt := range csconfig.CONSOLE_CONFIGS { + if opt != availableOpt { + continue + } + + valid = true + enable := true + + for _, enabledOpt := range enableOpts { + if opt == enabledOpt { + enable = false + continue + } + } + + if enable { + enableOpts = append(enableOpts, opt) + } + + break + } + + if !valid { + return fmt.Errorf("option %s doesn't exist", opt) + } + } + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + + c, _ := apiclient.NewClient(&apiclient.Config{ + MachineID: cli.cfg().API.Server.OnlineClient.Credentials.Login, + Password: password, + Scenarios: hub.GetInstalledListForAPI(), + URL: apiURL, + VersionPrefix: "v3", + }) + + resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite) + if err != nil { + return fmt.Errorf("could not enroll instance: %w", err) + } + + if resp.Response.StatusCode == http.StatusOK && !overwrite { + log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll") + return nil + } + + if err := cli.setConsoleOpts(enableOpts, true); err != nil { + return err + } + + for _, opt := range enableOpts { + log.Infof("Enabled %s : %s", opt, csconfig.CONSOLE_CONFIGS_HELP[opt]) + } + + log.Info("Watcher successfully enrolled. Visit https://app.crowdsec.net to accept it.") + log.Info("Please restart crowdsec after accepting the enrollment.") + + return nil +} + +func (cli *cliConsole) newEnrollCmd() *cobra.Command { + name := "" + overwrite := false + tags := []string{} + opts := []string{} + + cmd := &cobra.Command{ + Use: "enroll [enroll-key]", + Short: "Enroll this instance to https://app.crowdsec.net [requires local API]", + Long: ` +Enroll this instance to https://app.crowdsec.net + +You can get your enrollment key by creating an account on https://app.crowdsec.net. +After running this command your will need to validate the enrollment in the webapp.`, + Example: fmt.Sprintf(`cscli console enroll YOUR-ENROLL-KEY + cscli console enroll --name [instance_name] YOUR-ENROLL-KEY + cscli console enroll --name [instance_name] --tags [tag_1] --tags [tag_2] YOUR-ENROLL-KEY + cscli console enroll --enable context,manual YOUR-ENROLL-KEY + + valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&name, "name", "n", "", "Name to display in the console") + flags.BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") + flags.StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") + flags.StringSliceVarP(&opts, "enable", "e", opts, "Enable console options") + + return cmd +} + +func (cli *cliConsole) newEnableCmd() *cobra.Command { + var enableAll bool + + cmd := &cobra.Command{ + Use: "enable [option]", + Short: "Enable a console option", + Example: "sudo cscli console enable tainted", + Long: ` +Enable given information push to the central API. Allows to empower the console`, + ValidArgs: csconfig.CONSOLE_CONFIGS, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + if enableAll { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil { + return err + } + log.Infof("All features have been enabled successfully") + } else { + if len(args) == 0 { + return errors.New("you must specify at least one feature to enable") + } + if err := cli.setConsoleOpts(args, true); err != nil { + return err + } + log.Infof("%v have been enabled", args) + } + + log.Info(reload.Message) + + return nil + }, + } + cmd.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") + + return cmd +} + +func (cli *cliConsole) newDisableCmd() *cobra.Command { + var disableAll bool + + cmd := &cobra.Command{ + Use: "disable [option]", + Short: "Disable a console option", + Example: "sudo cscli console disable tainted", + Long: ` +Disable given information push to the central API.`, + ValidArgs: csconfig.CONSOLE_CONFIGS, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + if disableAll { + if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil { + return err + } + log.Infof("All features have been disabled") + } else { + if err := cli.setConsoleOpts(args, false); err != nil { + return err + } + log.Infof("%v have been disabled", args) + } + + log.Info(reload.Message) + + return nil + }, + } + cmd.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") + + return cmd +} + +func (cli *cliConsole) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Shows status of the console options", + Example: `sudo cscli console status`, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + switch cfg.Cscli.Output { + case "human": + cmdConsoleStatusTable(color.Output, cfg.Cscli.Color, *consoleCfg) + case "json": + out := map[string](*bool){ + csconfig.SEND_MANUAL_SCENARIOS: consoleCfg.ShareManualDecisions, + csconfig.SEND_CUSTOM_SCENARIOS: consoleCfg.ShareCustomScenarios, + csconfig.SEND_TAINTED_SCENARIOS: consoleCfg.ShareTaintedScenarios, + csconfig.SEND_CONTEXT: consoleCfg.ShareContext, + csconfig.CONSOLE_MANAGEMENT: consoleCfg.ConsoleManagement, + } + data, err := json.MarshalIndent(out, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize configuration: %w", err) + } + fmt.Println(string(data)) + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + err := csvwriter.Write([]string{"option", "enabled"}) + if err != nil { + return err + } + + rows := [][]string{ + {csconfig.SEND_MANUAL_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareManualDecisions)}, + {csconfig.SEND_CUSTOM_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareCustomScenarios)}, + {csconfig.SEND_TAINTED_SCENARIOS, strconv.FormatBool(*consoleCfg.ShareTaintedScenarios)}, + {csconfig.SEND_CONTEXT, strconv.FormatBool(*consoleCfg.ShareContext)}, + {csconfig.CONSOLE_MANAGEMENT, strconv.FormatBool(*consoleCfg.ConsoleManagement)}, + } + for _, row := range rows { + err = csvwriter.Write(row) + if err != nil { + return err + } + } + csvwriter.Flush() + } + + return nil + }, + } + + return cmd +} + +func (cli *cliConsole) dumpConfig() error { + serverCfg := cli.cfg().API.Server + + out, err := yaml.Marshal(serverCfg.ConsoleConfig) + if err != nil { + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err) + } + + if serverCfg.ConsoleConfigPath == "" { + serverCfg.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath + log.Debugf("Empty console_path, defaulting to %s", serverCfg.ConsoleConfigPath) + } + + if err := os.WriteFile(serverCfg.ConsoleConfigPath, out, 0o600); err != nil { + return fmt.Errorf("while dumping console config to %s: %w", serverCfg.ConsoleConfigPath, err) + } + + return nil +} + +func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { + cfg := cli.cfg() + consoleCfg := cfg.API.Server.ConsoleConfig + + for _, arg := range args { + switch arg { + case csconfig.CONSOLE_MANAGEMENT: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ConsoleManagement != nil && *consoleCfg.ConsoleManagement == wanted { + log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) + } else { + log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) + consoleCfg.ConsoleManagement = ptr.Of(wanted) + } + + if cfg.API.Server.OnlineClient.Credentials != nil { + changed := false + if wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL == "" { + changed = true + cfg.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL + } else if !wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL != "" { + changed = true + cfg.API.Server.OnlineClient.Credentials.PapiURL = "" + } + + if changed { + fileContent, err := yaml.Marshal(cfg.API.Server.OnlineClient.Credentials) + if err != nil { + return fmt.Errorf("cannot serialize credentials: %w", err) + } + + log.Infof("Updating credentials file: %s", cfg.API.Server.OnlineClient.CredentialsFilePath) + + err = os.WriteFile(cfg.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600) + if err != nil { + return fmt.Errorf("cannot write credentials file: %w", err) + } + } + } + case csconfig.SEND_CUSTOM_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareCustomScenarios != nil && *consoleCfg.ShareCustomScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) + consoleCfg.ShareCustomScenarios = ptr.Of(wanted) + } + case csconfig.SEND_TAINTED_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareTaintedScenarios != nil && *consoleCfg.ShareTaintedScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) + consoleCfg.ShareTaintedScenarios = ptr.Of(wanted) + } + case csconfig.SEND_MANUAL_SCENARIOS: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareManualDecisions != nil && *consoleCfg.ShareManualDecisions == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) + consoleCfg.ShareManualDecisions = ptr.Of(wanted) + } + case csconfig.SEND_CONTEXT: + /*for each flag check if it's already set before setting it*/ + if consoleCfg.ShareContext != nil && *consoleCfg.ShareContext == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) + } else { + log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) + consoleCfg.ShareContext = ptr.Of(wanted) + } + default: + return fmt.Errorf("unknown flag %s", arg) + } + } + + if err := cli.dumpConfig(); err != nil { + return fmt.Errorf("failed writing console config: %w", err) + } + + return nil +} diff --git a/cmd/crowdsec-cli/cliconsole/console_table.go b/cmd/crowdsec-cli/cliconsole/console_table.go new file mode 100644 index 00000000000..8f17b97860a --- /dev/null +++ b/cmd/crowdsec-cli/cliconsole/console_table.go @@ -0,0 +1,50 @@ +package cliconsole + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func cmdConsoleStatusTable(out io.Writer, wantColor string, consoleCfg csconfig.ConsoleConfig) { + t := cstable.New(out, wantColor) + t.SetRowLines(false) + + t.SetHeaders("Option Name", "Activated", "Description") + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + + for _, option := range csconfig.CONSOLE_CONFIGS { + activated := emoji.CrossMark + + switch option { + case csconfig.SEND_CUSTOM_SCENARIOS: + if *consoleCfg.ShareCustomScenarios { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_MANUAL_SCENARIOS: + if *consoleCfg.ShareManualDecisions { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_TAINTED_SCENARIOS: + if *consoleCfg.ShareTaintedScenarios { + activated = emoji.CheckMarkButton + } + case csconfig.SEND_CONTEXT: + if *consoleCfg.ShareContext { + activated = emoji.CheckMarkButton + } + case csconfig.CONSOLE_MANAGEMENT: + if *consoleCfg.ConsoleManagement { + activated = emoji.CheckMarkButton + } + } + + t.AddRow(option, activated, csconfig.CONSOLE_CONFIGS_HELP[option]) + } + + t.Render() +} diff --git a/cmd/crowdsec-cli/clidecision/decisions.go b/cmd/crowdsec-cli/clidecision/decisions.go new file mode 100644 index 00000000000..1f8781a3716 --- /dev/null +++ b/cmd/crowdsec-cli/clidecision/decisions.go @@ -0,0 +1,565 @@ +package clidecision + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (cli *cliDecisions) decisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { + /*here we cheat a bit : to make it more readable for the user, we dedup some entries*/ + spamLimit := make(map[string]bool) + skipped := 0 + + for aIdx := range len(*alerts) { + alertItem := (*alerts)[aIdx] + newDecisions := make([]*models.Decision, 0) + + for _, decisionItem := range alertItem.Decisions { + spamKey := fmt.Sprintf("%t:%s:%s:%s", *decisionItem.Simulated, *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) + if _, ok := spamLimit[spamKey]; ok { + skipped++ + continue + } + + spamLimit[spamKey] = true + + newDecisions = append(newDecisions, decisionItem) + } + + alertItem.Decisions = newDecisions + } + + switch cli.cfg().Cscli.Output { + case "raw": + csvwriter := csv.NewWriter(os.Stdout) + header := []string{"id", "source", "ip", "reason", "action", "country", "as", "events_count", "expiration", "simulated", "alert_id"} + + if printMachine { + header = append(header, "machine") + } + + err := csvwriter.Write(header) + if err != nil { + return err + } + + for _, alertItem := range *alerts { + for _, decisionItem := range alertItem.Decisions { + raw := []string{ + fmt.Sprintf("%d", decisionItem.ID), + *decisionItem.Origin, + *decisionItem.Scope + ":" + *decisionItem.Value, + *decisionItem.Scenario, + *decisionItem.Type, + alertItem.Source.Cn, + alertItem.Source.GetAsNumberName(), + fmt.Sprintf("%d", *alertItem.EventsCount), + *decisionItem.Duration, + fmt.Sprintf("%t", *decisionItem.Simulated), + fmt.Sprintf("%d", alertItem.ID), + } + if printMachine { + raw = append(raw, alertItem.MachineID) + } + + err := csvwriter.Write(raw) + if err != nil { + return err + } + } + } + + csvwriter.Flush() + case "json": + if *alerts == nil { + // avoid returning "null" in `json" + // could be cleaner if we used slice of alerts directly + fmt.Println("[]") + return nil + } + + x, _ := json.MarshalIndent(alerts, "", " ") + fmt.Printf("%s", string(x)) + case "human": + if len(*alerts) == 0 { + fmt.Println("No active decisions") + return nil + } + + cli.decisionsTable(color.Output, alerts, printMachine) + + if skipped > 0 { + fmt.Printf("%d duplicated entries skipped\n", skipped) + } + } + + return nil +} + +type configGetter func() *csconfig.Config + +type cliDecisions struct { + client *apiclient.ApiClient + cfg configGetter +} + +func New(cfg configGetter) *cliDecisions { + return &cliDecisions{ + cfg: cfg, + } +} + +func (cli *cliDecisions) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "decisions [action]", + Short: "Manage decisions", + Long: `Add/List/Delete/Import decisions from LAPI`, + Example: `cscli decisions [action] [filter]`, + Aliases: []string{"decision"}, + /*TBD example*/ + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + cli.client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return fmt.Errorf("creating api client: %w", err) + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newImportCmd()) + + return cmd +} + +func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { + var err error + + *filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals) + if err != nil { + return err + } + + filter.ActiveDecisionEquals = new(bool) + *filter.ActiveDecisionEquals = true + + if NoSimu != nil && *NoSimu { + filter.IncludeSimulated = new(bool) + } + /* nullify the empty entries to avoid bad filter */ + if *filter.Until == "" { + filter.Until = nil + } else if strings.HasSuffix(*filter.Until, "d") { + /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ + realDuration := strings.TrimSuffix(*filter.Until, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Until) + } + + *filter.Until = fmt.Sprintf("%d%s", days*24, "h") + } + + if *filter.Since == "" { + filter.Since = nil + } else if strings.HasSuffix(*filter.Since, "d") { + /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ + realDuration := strings.TrimSuffix(*filter.Since, "d") + + days, err := strconv.Atoi(realDuration) + if err != nil { + return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Since) + } + + *filter.Since = fmt.Sprintf("%d%s", days*24, "h") + } + + if *filter.IncludeCAPI { + *filter.Limit = 0 + } + + if *filter.TypeEquals == "" { + filter.TypeEquals = nil + } + + if *filter.ValueEquals == "" { + filter.ValueEquals = nil + } + + if *filter.ScopeEquals == "" { + filter.ScopeEquals = nil + } + + if *filter.ScenarioEquals == "" { + filter.ScenarioEquals = nil + } + + if *filter.IPEquals == "" { + filter.IPEquals = nil + } + + if *filter.RangeEquals == "" { + filter.RangeEquals = nil + } + + if *filter.OriginEquals == "" { + filter.OriginEquals = nil + } + + if contained != nil && *contained { + filter.Contains = new(bool) + } + + alerts, _, err := cli.client.Alerts.List(ctx, filter) + if err != nil { + return fmt.Errorf("unable to retrieve decisions: %w", err) + } + + err = cli.decisionsToTable(alerts, printMachine) + if err != nil { + return fmt.Errorf("unable to print decisions: %w", err) + } + + return nil +} + +func (cli *cliDecisions) newListCmd() *cobra.Command { + filter := apiclient.AlertsListOpts{ + ValueEquals: new(string), + ScopeEquals: new(string), + ScenarioEquals: new(string), + OriginEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + Since: new(string), + Until: new(string), + TypeEquals: new(string), + IncludeCAPI: new(bool), + Limit: new(int), + } + + NoSimu := new(bool) + contained := new(bool) + + var printMachine bool + + cmd := &cobra.Command{ + Use: "list [options]", + Short: "List decisions from LAPI", + Example: `cscli decisions list -i 1.2.3.4 +cscli decisions list -r 1.2.3.0/24 +cscli decisions list -s crowdsecurity/ssh-bf +cscli decisions list --origin lists --scenario list_name +`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.BoolVarP(filter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") + flags.StringVar(filter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") + flags.StringVar(filter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") + flags.StringVarP(filter.TypeEquals, "type", "t", "", "restrict to this decision type (ie. ban,captcha)") + flags.StringVar(filter.ScopeEquals, "scope", "", "restrict to this scope (ie. ip,range,session)") + flags.StringVar(filter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) + flags.StringVarP(filter.ValueEquals, "value", "v", "", "restrict to this value (ie. 1.2.3.4,userName)") + flags.StringVarP(filter.ScenarioEquals, "scenario", "s", "", "restrict to this scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(filter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") + flags.StringVarP(filter.RangeEquals, "range", "r", "", "restrict to alerts from this source range (shorthand for --scope range --value )") + flags.IntVarP(filter.Limit, "limit", "l", 100, "number of alerts to get (use 0 to remove the limit)") + flags.BoolVar(NoSimu, "no-simu", false, "exclude decisions in simulation mode") + flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that triggered decisions") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + + return cmd +} + +func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { + alerts := models.AddAlertsRequest{} + origin := types.CscliOrigin + capacity := int32(0) + leakSpeed := "0" + eventsCount := int32(1) + empty := "" + simulated := false + startAt := time.Now().UTC().Format(time.RFC3339) + stopAt := time.Now().UTC().Format(time.RFC3339) + createdAt := time.Now().UTC().Format(time.RFC3339) + + var err error + + addScope, err = clialert.SanitizeScope(addScope, addIP, addRange) + if err != nil { + return err + } + + if addIP != "" { + addValue = addIP + addScope = types.Ip + } else if addRange != "" { + addValue = addRange + addScope = types.Range + } else if addValue == "" { + return errors.New("missing arguments, a value is required (--ip, --range or --scope and --value)") + } + + if addReason == "" { + addReason = fmt.Sprintf("manual '%s' from '%s'", addType, cli.cfg().API.Client.Credentials.Login) + } + + decision := models.Decision{ + Duration: &addDuration, + Scope: &addScope, + Value: &addValue, + Type: &addType, + Scenario: &addReason, + Origin: &origin, + } + alert := models.Alert{ + Capacity: &capacity, + Decisions: []*models.Decision{&decision}, + Events: []*models.Event{}, + EventsCount: &eventsCount, + Leakspeed: &leakSpeed, + Message: &addReason, + ScenarioHash: &empty, + Scenario: &addReason, + ScenarioVersion: &empty, + Simulated: &simulated, + // setting empty scope/value broke plugins, and it didn't seem to be needed anymore w/ latest papi changes + Source: &models.Source{ + AsName: "", + AsNumber: "", + Cn: "", + IP: addValue, + Range: "", + Scope: &addScope, + Value: &addValue, + }, + StartAt: &startAt, + StopAt: &stopAt, + CreatedAt: createdAt, + Remediation: true, + } + alerts = append(alerts, &alert) + + _, _, err = cli.client.Alerts.Add(ctx, alerts) + if err != nil { + return err + } + + log.Info("Decision successfully added") + + return nil +} + +func (cli *cliDecisions) newAddCmd() *cobra.Command { + var ( + addIP string + addRange string + addDuration string + addValue string + addScope string + addReason string + addType string + ) + + cmd := &cobra.Command{ + Use: "add [options]", + Short: "Add decision to LAPI", + Example: `cscli decisions add --ip 1.2.3.4 +cscli decisions add --range 1.2.3.0/24 +cscli decisions add --ip 1.2.3.4 --duration 24h --type captcha +cscli decisions add --scope username --value foobar +`, + /*TBD : fix long and example*/ + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVarP(&addIP, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(&addRange, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVarP(&addDuration, "duration", "d", "4h", "Decision duration (ie. 1h,4h,30m)") + flags.StringVarP(&addValue, "value", "v", "", "The value (ie. --scope username --value foobar)") + flags.StringVar(&addScope, "scope", types.Ip, "Decision scope (ie. ip,range,username)") + flags.StringVarP(&addReason, "reason", "R", "", "Decision reason (ie. scenario-name)") + flags.StringVarP(&addType, "type", "t", "ban", "Decision type (ie. ban,captcha,throttle)") + + return cmd +} + +func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { + var err error + + /*take care of shorthand options*/ + *delFilter.ScopeEquals, err = clialert.SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { + return err + } + + if *delFilter.ScopeEquals == "" { + delFilter.ScopeEquals = nil + } + + if *delFilter.OriginEquals == "" { + delFilter.OriginEquals = nil + } + + if *delFilter.ValueEquals == "" { + delFilter.ValueEquals = nil + } + + if *delFilter.ScenarioEquals == "" { + delFilter.ScenarioEquals = nil + } + + if *delFilter.TypeEquals == "" { + delFilter.TypeEquals = nil + } + + if *delFilter.IPEquals == "" { + delFilter.IPEquals = nil + } + + if *delFilter.RangeEquals == "" { + delFilter.RangeEquals = nil + } + + if contained != nil && *contained { + delFilter.Contains = new(bool) + } + + var decisions *models.DeleteDecisionResponse + + if delDecisionID == "" { + decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter) + if err != nil { + return fmt.Errorf("unable to delete decisions: %w", err) + } + } else { + if _, err = strconv.Atoi(delDecisionID); err != nil { + return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err) + } + + decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID) + if err != nil { + return fmt.Errorf("unable to delete decision: %w", err) + } + } + + log.Infof("%s decision(s) deleted", decisions.NbDeleted) + + return nil +} + +func (cli *cliDecisions) newDeleteCmd() *cobra.Command { + delFilter := apiclient.DecisionsDeleteOpts{ + ScopeEquals: new(string), + ValueEquals: new(string), + TypeEquals: new(string), + IPEquals: new(string), + RangeEquals: new(string), + ScenarioEquals: new(string), + OriginEquals: new(string), + } + + var delDecisionID string + + var delDecisionAll bool + + contained := new(bool) + + cmd := &cobra.Command{ + Use: "delete [options]", + Short: "Delete decisions", + DisableAutoGenTag: true, + Aliases: []string{"remove"}, + Example: `cscli decisions delete -r 1.2.3.0/24 +cscli decisions delete -i 1.2.3.4 +cscli decisions delete --id 42 +cscli decisions delete --type captcha +cscli decisions delete --origin lists --scenario list_name +`, + /*TBD : refaire le Long/Example*/ + PreRunE: func(cmd *cobra.Command, _ []string) error { + if delDecisionAll { + return nil + } + if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && + *delFilter.TypeEquals == "" && *delFilter.IPEquals == "" && + *delFilter.RangeEquals == "" && *delFilter.ScenarioEquals == "" && + *delFilter.OriginEquals == "" && delDecisionID == "" { + _ = cmd.Usage() + return errors.New("at least one filter or --all must be specified") + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, delDecisionID, contained) + }, + } + + flags := cmd.Flags() + flags.SortFlags = false + flags.StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVarP(delFilter.TypeEquals, "type", "t", "", "the decision type (ie. ban,captcha)") + flags.StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario name (ie. crowdsecurity/ssh-bf)") + flags.StringVar(delFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) + + flags.StringVar(&delDecisionID, "id", "", "decision id") + flags.BoolVar(&delDecisionAll, "all", false, "delete all decisions") + flags.BoolVar(contained, "contained", false, "query decisions contained by range") + + return cmd +} diff --git a/cmd/crowdsec-cli/decisions_import.go b/cmd/crowdsec-cli/clidecision/decisions_import.go similarity index 88% rename from cmd/crowdsec-cli/decisions_import.go rename to cmd/crowdsec-cli/clidecision/decisions_import.go index 2d7ee485bd1..10d92f88876 100644 --- a/cmd/crowdsec-cli/decisions_import.go +++ b/cmd/crowdsec-cli/clidecision/decisions_import.go @@ -1,10 +1,11 @@ -package main +package clidecision import ( "bufio" "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -45,7 +46,7 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) { } if err := scanner.Err(); err != nil { - return nil, fmt.Errorf("unable to parse values: '%s'", err) + return nil, fmt.Errorf("unable to parse values: '%w'", err) } case "json": log.Infof("Parsing json") @@ -57,7 +58,7 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) { log.Infof("Parsing csv") if err := csvutil.Unmarshal(content, &ret); err != nil { - return nil, fmt.Errorf("unable to parse csv: '%s'", err) + return nil, fmt.Errorf("unable to parse csv: '%w'", err) } default: return nil, fmt.Errorf("invalid format '%s', expected one of 'json', 'csv', 'values'", format) @@ -66,8 +67,7 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) { return ret, nil } - -func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { +func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { flags := cmd.Flags() input, err := flags.GetString("input") @@ -81,7 +81,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if defaultDuration == "" { - return fmt.Errorf("--duration cannot be empty") + return errors.New("--duration cannot be empty") } defaultScope, err := flags.GetString("scope") @@ -90,7 +90,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if defaultScope == "" { - return fmt.Errorf("--scope cannot be empty") + return errors.New("--scope cannot be empty") } defaultReason, err := flags.GetString("reason") @@ -99,7 +99,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if defaultReason == "" { - return fmt.Errorf("--reason cannot be empty") + return errors.New("--reason cannot be empty") } defaultType, err := flags.GetString("type") @@ -108,7 +108,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if defaultType == "" { - return fmt.Errorf("--type cannot be empty") + return errors.New("--type cannot be empty") } batchSize, err := flags.GetInt("batch") @@ -123,7 +123,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { var ( content []byte - fin *os.File + fin *os.File ) // set format if the file has a json or csv extension @@ -136,7 +136,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if format == "" { - return fmt.Errorf("unable to guess format from file extension, please provide a format with --format flag") + return errors.New("unable to guess format from file extension, please provide a format with --format flag") } if input == "-" { @@ -145,13 +145,13 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { } else { fin, err = os.Open(input) if err != nil { - return fmt.Errorf("unable to open %s: %s", input, err) + return fmt.Errorf("unable to open %s: %w", input, err) } } content, err = io.ReadAll(fin) if err != nil { - return fmt.Errorf("unable to read from %s: %s", input, err) + return fmt.Errorf("unable to read from %s: %w", input, err) } decisionsListRaw, err := parseDecisionList(content, format) @@ -224,7 +224,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { Decisions: chunk, } - _, _, err = Client.Alerts.Add(context.Background(), models.AddAlertsRequest{&importAlert}) + _, _, err = cli.client.Alerts.Add(context.Background(), models.AddAlertsRequest{&importAlert}) if err != nil { return err } @@ -235,15 +235,14 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error { return nil } - -func (cli cliDecisions) NewImportCmd() *cobra.Command { +func (cli *cliDecisions) newImportCmd() *cobra.Command { cmd := &cobra.Command{ Use: "import [options]", Short: "Import decisions from a file or pipe", Long: "expected format:\n" + "csv : any of duration,reason,scope,type,value, with a header line\n" + "json :" + "`{" + `"duration" : "24h", "reason" : "my_scenario", "scope" : "ip", "type" : "ban", "value" : "x.y.z.z"` + "}`", - Args: cobra.NoArgs, + Args: cobra.NoArgs, DisableAutoGenTag: true, Example: `decisions.csv: duration,scope,value @@ -274,7 +273,7 @@ $ echo "1.2.3.4" | cscli decisions import -i - --format values flags.Int("batch", 0, "Split import in batches of N decisions") flags.String("format", "", "Input format: 'json', 'csv' or 'values' (each line is a value, no headers)") - cmd.MarkFlagRequired("input") + _ = cmd.MarkFlagRequired("input") return cmd } diff --git a/cmd/crowdsec-cli/decisions_table.go b/cmd/crowdsec-cli/clidecision/decisions_table.go similarity index 80% rename from cmd/crowdsec-cli/decisions_table.go rename to cmd/crowdsec-cli/clidecision/decisions_table.go index d8d5e032594..90a0ae1176b 100644 --- a/cmd/crowdsec-cli/decisions_table.go +++ b/cmd/crowdsec-cli/clidecision/decisions_table.go @@ -1,20 +1,23 @@ -package main +package clidecision import ( "fmt" "io" "strconv" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { - t := newTable(out) +func (cli *cliDecisions) decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) { + t := cstable.New(out, cli.cfg().Cscli.Color) t.SetRowLines(false) + header := []string{"ID", "Source", "Scope:Value", "Reason", "Action", "Country", "AS", "Events", "expiration", "Alert ID"} if printMachine { header = append(header, "Machine") } + t.SetHeaders(header...) for _, alertItem := range *alerts { @@ -22,6 +25,7 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin if *alertItem.Simulated { *decisionItem.Type = fmt.Sprintf("(simul)%s", *decisionItem.Type) } + row := []string{ strconv.Itoa(int(decisionItem.ID)), *decisionItem.Origin, @@ -42,5 +46,6 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin t.AddRow(row...) } } + t.Render() } diff --git a/cmd/crowdsec-cli/clientinfo/clientinfo.go b/cmd/crowdsec-cli/clientinfo/clientinfo.go new file mode 100644 index 00000000000..0bf1d98804f --- /dev/null +++ b/cmd/crowdsec-cli/clientinfo/clientinfo.go @@ -0,0 +1,39 @@ +package clientinfo + +import ( + "strings" +) + +type featureflagProvider interface { + GetFeatureflags() string +} + +type osProvider interface { + GetOsname() string + GetOsversion() string +} + +func GetOSNameAndVersion(o osProvider) string { + ret := o.GetOsname() + if o.GetOsversion() != "" { + if ret != "" { + ret += "/" + } + + ret += o.GetOsversion() + } + + if ret == "" { + return "?" + } + + return ret +} + +func GetFeatureFlagList(o featureflagProvider) []string { + if o.GetFeatureflags() == "" { + return nil + } + + return strings.Split(o.GetFeatureflags(), ",") +} diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/cliexplain/explain.go similarity index 53% rename from cmd/crowdsec-cli/explain.go rename to cmd/crowdsec-cli/cliexplain/explain.go index d21c1704930..182e34a12a5 100644 --- a/cmd/crowdsec-cli/explain.go +++ b/cmd/crowdsec-cli/cliexplain/explain.go @@ -1,4 +1,4 @@ -package main +package cliexplain import ( "bufio" @@ -12,37 +12,62 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) -func GetLineCountForFile(filepath string) (int, error) { +func getLineCountForFile(filepath string) (int, error) { f, err := os.Open(filepath) if err != nil { return 0, err } defer f.Close() + lc := 0 fs := bufio.NewReader(f) + for { input, err := fs.ReadBytes('\n') if len(input) > 1 { lc++ } + if err != nil && err == io.EOF { break } } + return lc, nil } -type cliExplain struct{} +type configGetter func() *csconfig.Config + +type cliExplain struct { + cfg configGetter + configFilePath string + flags struct { + logFile string + dsn string + logLine string + logType string + details bool + skipOk bool + onlySuccessfulParsers bool + noClean bool + crowdsec string + labels string + } +} -func NewCLIExplain() *cliExplain { - return &cliExplain{} +func New(cfg configGetter, configFilePath string) *cliExplain { + return &cliExplain{ + cfg: cfg, + configFilePath: configFilePath, + } } -func (cli cliExplain) NewCommand() *cobra.Command { +func (cli *cliExplain) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "explain", Short: "Explain log pipeline", @@ -57,118 +82,50 @@ tail -n 5 myfile.log | cscli explain --type nginx -f - `, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.run, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - logFile, err := flags.GetString("file") - if err != nil { - return err - } - - dsn, err := flags.GetString("dsn") - if err != nil { - return err - } - - logLine, err := flags.GetString("log") - if err != nil { - return err - } - - logType, err := flags.GetString("type") - if err != nil { - return err - } - - if logLine == "" && logFile == "" && dsn == "" { - printHelp(cmd) - fmt.Println() - return fmt.Errorf("please provide --log, --file or --dsn flag") - } - if logType == "" { - printHelp(cmd) - fmt.Println() - return fmt.Errorf("please provide --type flag") - } + RunE: func(_ *cobra.Command, _ []string) error { + return cli.run() + }, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { fileInfo, _ := os.Stdin.Stat() - if logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { - return fmt.Errorf("the option -f - is intended to work with pipes") + if cli.flags.logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) { + return errors.New("the option -f - is intended to work with pipes") } + return nil }, } flags := cmd.Flags() - flags.StringP("file", "f", "", "Log file to test") - flags.StringP("dsn", "d", "", "DSN to test") - flags.StringP("log", "l", "", "Log line to test") - flags.StringP("type", "t", "", "Type of the acquisition to test") - flags.String("labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)") - flags.BoolP("verbose", "v", false, "Display individual changes") - flags.Bool("failures", false, "Only show failed lines") - flags.Bool("only-successful-parsers", false, "Only show successful parsers") - flags.String("crowdsec", "crowdsec", "Path to crowdsec") - flags.Bool("no-clean", false, "Don't clean runtime environment after tests") + flags.StringVarP(&cli.flags.logFile, "file", "f", "", "Log file to test") + flags.StringVarP(&cli.flags.dsn, "dsn", "d", "", "DSN to test") + flags.StringVarP(&cli.flags.logLine, "log", "l", "", "Log line to test") + flags.StringVarP(&cli.flags.logType, "type", "t", "", "Type of the acquisition to test") + flags.StringVar(&cli.flags.labels, "labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)") + flags.BoolVarP(&cli.flags.details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&cli.flags.skipOk, "failures", false, "Only show failed lines") + flags.BoolVar(&cli.flags.onlySuccessfulParsers, "only-successful-parsers", false, "Only show successful parsers") + flags.StringVar(&cli.flags.crowdsec, "crowdsec", "crowdsec", "Path to crowdsec") + flags.BoolVar(&cli.flags.noClean, "no-clean", false, "Don't clean runtime environment after tests") + + _ = cmd.MarkFlagRequired("type") + cmd.MarkFlagsOneRequired("log", "file", "dsn") return cmd } -func (cli cliExplain) run(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() +func (cli *cliExplain) run() error { + logFile := cli.flags.logFile + logLine := cli.flags.logLine + logType := cli.flags.logType + dsn := cli.flags.dsn + labels := cli.flags.labels + crowdsec := cli.flags.crowdsec - logFile, err := flags.GetString("file") - if err != nil { - return err - } - - dsn, err := flags.GetString("dsn") - if err != nil { - return err - } - - logLine, err := flags.GetString("log") - if err != nil { - return err - } - - logType, err := flags.GetString("type") - if err != nil { - return err - } - - opts := dumps.DumpOpts{} - - opts.Details, err = flags.GetBool("verbose") - if err != nil { - return err - } - - no_clean, err := flags.GetBool("no-clean") - if err != nil { - return err - } - - opts.SkipOk, err = flags.GetBool("failures") - if err != nil { - return err - } - - opts.ShowNotOkParsers, err = flags.GetBool("only-successful-parsers") - opts.ShowNotOkParsers = !opts.ShowNotOkParsers - if err != nil { - return err - } - - crowdsec, err := flags.GetString("crowdsec") - if err != nil { - return err - } - - labels, err := flags.GetString("labels") - if err != nil { - return err + opts := dumps.DumpOpts{ + Details: cli.flags.details, + SkipOk: cli.flags.skipOk, + ShowNotOkParsers: !cli.flags.onlySuccessfulParsers, } var f *os.File @@ -176,21 +133,25 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { // using empty string fallback to /tmp dir, err := os.MkdirTemp("", "cscli_explain") if err != nil { - return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err) + return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %w", err) } + defer func() { - if no_clean { + if cli.flags.noClean { return } + if _, err := os.Stat(dir); !os.IsNotExist(err) { if err := os.RemoveAll(dir); err != nil { log.Errorf("unable to delete temporary directory '%s': %s", dir, err) } } }() + // we create a temporary log file if a log line/stdin has been provided if logLine != "" || logFile == "-" { tmpFile := filepath.Join(dir, "cscli_test_tmp.log") + f, err = os.Create(tmpFile) if err != nil { return err @@ -204,22 +165,27 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { } else if logFile == "-" { reader := bufio.NewReader(os.Stdin) errCount := 0 + for { input, err := reader.ReadBytes('\n') if err != nil && errors.Is(err, io.EOF) { break } + if len(input) > 1 { _, err = f.Write(input) } + if err != nil || len(input) <= 1 { errCount++ } } + if errCount > 0 { log.Warnf("Failed to write %d lines to %s", errCount, tmpFile) } } + f.Close() // this is the file that was going to be read by crowdsec anyway logFile = tmpFile @@ -230,34 +196,43 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile) } + dsn = fmt.Sprintf("file://%s", absolutePath) - lineCount, err := GetLineCountForFile(absolutePath) + + lineCount, err := getLineCountForFile(absolutePath) if err != nil { return err } + log.Debugf("file %s has %d lines", absolutePath, lineCount) + if lineCount == 0 { return fmt.Errorf("the log file is empty: %s", absolutePath) } + if lineCount > 100 { log.Warnf("%s contains %d lines. This may take a lot of resources.", absolutePath, lineCount) } } if dsn == "" { - return fmt.Errorf("no acquisition (--file or --dsn) provided, can't run cscli test") + return errors.New("no acquisition (--file or --dsn) provided, can't run cscli test") } - cmdArgs := []string{"-c", ConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} + cmdArgs := []string{"-c", cli.configFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} + if labels != "" { log.Debugf("adding labels %s", labels) cmdArgs = append(cmdArgs, "-label", labels) } + crowdsecCmd := exec.Command(crowdsec, cmdArgs...) + output, err := crowdsecCmd.CombinedOutput() if err != nil { fmt.Println(string(output)) - return fmt.Errorf("fail to run crowdsec for test: %v", err) + + return fmt.Errorf("fail to run crowdsec for test: %w", err) } parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName) @@ -265,12 +240,12 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { parserDump, err := dumps.LoadParserDump(parserDumpFile) if err != nil { - return fmt.Errorf("unable to load parser dump result: %s", err) + return fmt.Errorf("unable to load parser dump result: %w", err) } bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile) if err != nil { - return fmt.Errorf("unable to load bucket dump result: %s", err) + return fmt.Errorf("unable to load bucket dump result: %w", err) } dumps.DumpTree(*parserDump, *bucketStateDump, opts) diff --git a/cmd/crowdsec-cli/hub.go b/cmd/crowdsec-cli/clihub/hub.go similarity index 55% rename from cmd/crowdsec-cli/hub.go rename to cmd/crowdsec-cli/clihub/hub.go index 3a2913f0513..22568355546 100644 --- a/cmd/crowdsec-cli/hub.go +++ b/cmd/crowdsec-cli/clihub/hub.go @@ -1,8 +1,10 @@ -package main +package clihub import ( + "context" "encoding/json" "fmt" + "io" "github.com/fatih/color" log "github.com/sirupsen/logrus" @@ -10,16 +12,23 @@ import ( "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type cliHub struct{} +type configGetter = func() *csconfig.Config -func NewCLIHub() *cliHub { - return &cliHub{} +type cliHub struct { + cfg configGetter } -func (cli cliHub) NewCommand() *cobra.Command { +func New(cfg configGetter) *cliHub { + return &cliHub{ + cfg: cfg, + } +} + +func (cli *cliHub) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "hub [action]", Short: "Manage hub index", @@ -34,26 +43,16 @@ cscli hub upgrade`, DisableAutoGenTag: true, } - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewUpdateCmd()) - cmd.AddCommand(cli.NewUpgradeCmd()) - cmd.AddCommand(cli.NewTypesCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newUpdateCmd()) + cmd.AddCommand(cli.newUpgradeCmd()) + cmd.AddCommand(cli.newTypesCmd()) return cmd } -func (cli cliHub) list(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - all, err := flags.GetBool("all") - if err != nil { - return err - } - - hub, err := require.Hub(csConfig, nil, log.StandardLogger()) - if err != nil { - return err - } +func (cli *cliHub) List(out io.Writer, hub *cwhub.Hub, all bool) error { + cfg := cli.cfg() for _, v := range hub.Warnings { log.Info(v) @@ -65,14 +64,16 @@ func (cli cliHub) list(cmd *cobra.Command, args []string) error { items := make(map[string][]*cwhub.Item) + var err error + for _, itemType := range cwhub.ItemTypes { - items[itemType], err = selectItems(hub, itemType, nil, !all) + items[itemType], err = SelectItems(hub, itemType, nil, !all) if err != nil { return err } } - err = listItems(color.Output, cwhub.ItemTypes, items, true) + err = ListItems(out, cfg.Cscli.Color, cwhub.ItemTypes, items, true, cfg.Cscli.Output) if err != nil { return err } @@ -80,31 +81,49 @@ func (cli cliHub) list(cmd *cobra.Command, args []string) error { return nil } -func (cli cliHub) NewListCmd() *cobra.Command { +func (cli *cliHub) newListCmd() *cobra.Command { + var all bool + cmd := &cobra.Command{ Use: "list [-a]", Short: "List all installed configurations", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.list, + RunE: func(_ *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + return cli.List(color.Output, hub, all) + }, } flags := cmd.Flags() - flags.BoolP("all", "a", false, "List disabled items as well") + flags.BoolVarP(&all, "all", "a", false, "List disabled items as well") return cmd } -func (cli cliHub) update(cmd *cobra.Command, args []string) error { - local := csConfig.Hub - remote := require.RemoteHub(csConfig) +func (cli *cliHub) update(ctx context.Context, withContent bool) error { + local := cli.cfg().Hub + remote := require.RemoteHub(ctx, cli.cfg()) + remote.EmbedItemContent = withContent // don't use require.Hub because if there is no index file, it would fail - hub, err := cwhub.NewHub(local, remote, true, log.StandardLogger()) + hub, err := cwhub.NewHub(local, remote, log.StandardLogger()) if err != nil { + return err + } + + if err := hub.Update(ctx); err != nil { return fmt.Errorf("failed to update hub: %w", err) } + if err := hub.Load(); err != nil { + return fmt.Errorf("failed to load hub: %w", err) + } + for _, v := range hub.Warnings { log.Info(v) } @@ -112,7 +131,9 @@ func (cli cliHub) update(cmd *cobra.Command, args []string) error { return nil } -func (cli cliHub) NewUpdateCmd() *cobra.Command { +func (cli *cliHub) newUpdateCmd() *cobra.Command { + withContent := false + cmd := &cobra.Command{ Use: "update", Short: "Download the latest index (catalog of available configurations)", @@ -121,37 +142,30 @@ Fetches the .index.json file from the hub, containing the list of available conf `, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.update, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.update(cmd.Context(), withContent) + }, } - return cmd -} - -func (cli cliHub) upgrade(cmd *cobra.Command, args []string) error { flags := cmd.Flags() + flags.BoolVar(&withContent, "with-content", false, "Download index with embedded item content") - force, err := flags.GetBool("force") - if err != nil { - return err - } + return cmd +} - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), log.StandardLogger()) +func (cli *cliHub) upgrade(ctx context.Context, force bool) error { + hub, err := require.Hub(cli.cfg(), require.RemoteHub(ctx, cli.cfg()), log.StandardLogger()) if err != nil { return err } for _, itemType := range cwhub.ItemTypes { - items, err := hub.GetInstalledItems(itemType) - if err != nil { - return err - } - updated := 0 log.Infof("Upgrading %s", itemType) - for _, item := range items { - didUpdate, err := item.Upgrade(force) + for _, item := range hub.GetInstalledByType(itemType, true) { + didUpdate, err := item.Upgrade(ctx, force) if err != nil { return err } @@ -167,7 +181,9 @@ func (cli cliHub) upgrade(cmd *cobra.Command, args []string) error { return nil } -func (cli cliHub) NewUpgradeCmd() *cobra.Command { +func (cli *cliHub) newUpgradeCmd() *cobra.Command { + var force bool + cmd := &cobra.Command{ Use: "upgrade", Short: "Upgrade all configurations to their latest version", @@ -176,17 +192,19 @@ Upgrade all configs installed from Crowdsec Hub. Run 'sudo cscli hub update' if `, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.upgrade, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.upgrade(cmd.Context(), force) + }, } flags := cmd.Flags() - flags.Bool("force", false, "Force upgrade: overwrite tainted and outdated files") + flags.BoolVar(&force, "force", false, "Force upgrade: overwrite tainted and outdated files") return cmd } -func (cli cliHub) types(cmd *cobra.Command, args []string) error { - switch csConfig.Cscli.Output { +func (cli *cliHub) types() error { + switch cli.cfg().Cscli.Output { case "human": s, err := yaml.Marshal(cwhub.ItemTypes) if err != nil { @@ -210,7 +228,7 @@ func (cli cliHub) types(cmd *cobra.Command, args []string) error { return nil } -func (cli cliHub) NewTypesCmd() *cobra.Command { +func (cli *cliHub) newTypesCmd() *cobra.Command { cmd := &cobra.Command{ Use: "types", Short: "List supported item types", @@ -219,7 +237,9 @@ List the types of supported hub items. `, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: cli.types, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.types() + }, } return cmd diff --git a/cmd/crowdsec-cli/item_metrics.go b/cmd/crowdsec-cli/clihub/item_metrics.go similarity index 78% rename from cmd/crowdsec-cli/item_metrics.go rename to cmd/crowdsec-cli/clihub/item_metrics.go index e6f27ae5d0d..f4af8f635db 100644 --- a/cmd/crowdsec-cli/item_metrics.go +++ b/cmd/crowdsec-cli/clihub/item_metrics.go @@ -1,8 +1,6 @@ -package main +package clihub import ( - "fmt" - "math" "net/http" "strconv" "strings" @@ -18,52 +16,59 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func ShowMetrics(hubItem *cwhub.Item) error { +func showMetrics(prometheusURL string, hubItem *cwhub.Item, wantColor string) error { switch hubItem.Type { case cwhub.PARSERS: - metrics := GetParserMetric(csConfig.Cscli.PrometheusUrl, hubItem.Name) - parserMetricsTable(color.Output, hubItem.Name, metrics) + metrics := getParserMetric(prometheusURL, hubItem.Name) + parserMetricsTable(color.Output, wantColor, hubItem.Name, metrics) case cwhub.SCENARIOS: - metrics := GetScenarioMetric(csConfig.Cscli.PrometheusUrl, hubItem.Name) - scenarioMetricsTable(color.Output, hubItem.Name, metrics) + metrics := getScenarioMetric(prometheusURL, hubItem.Name) + scenarioMetricsTable(color.Output, wantColor, hubItem.Name, metrics) case cwhub.COLLECTIONS: for _, sub := range hubItem.SubItems() { - if err := ShowMetrics(sub); err != nil { + if err := showMetrics(prometheusURL, sub, wantColor); err != nil { return err } } case cwhub.APPSEC_RULES: - metrics := GetAppsecRuleMetric(csConfig.Cscli.PrometheusUrl, hubItem.Name) - appsecMetricsTable(color.Output, hubItem.Name, metrics) + metrics := getAppsecRuleMetric(prometheusURL, hubItem.Name) + appsecMetricsTable(color.Output, wantColor, hubItem.Name, metrics) default: // no metrics for this item type } + return nil } -// GetParserMetric is a complete rip from prom2json -func GetParserMetric(url string, itemName string) map[string]map[string]int { +// getParserMetric is a complete rip from prom2json +func getParserMetric(url string, itemName string) map[string]map[string]int { stats := make(map[string]map[string]int) - result := GetPrometheusMetric(url) + result := getPrometheusMetric(url) for idx, fam := range result { if !strings.HasPrefix(fam.Name, "cs_") { continue } + log.Tracef("round %d", idx) + for _, m := range fam.Metrics { metric, ok := m.(prom2json.Metric) if !ok { log.Debugf("failed to convert metric to prom2json.Metric") continue } + name, ok := metric.Labels["name"] if !ok { log.Debugf("no name in Metric %v", metric.Labels) } + if name != itemName { continue } + source, ok := metric.Labels["source"] + if !ok { log.Debugf("no source in Metric %v", metric.Labels) } else { @@ -71,12 +76,15 @@ func GetParserMetric(url string, itemName string) map[string]map[string]int { source = srctype + ":" + source } } + value := m.(prom2json.Metric).Value + fval, err := strconv.ParseFloat(value, 32) if err != nil { log.Errorf("Unexpected int value %s : %s", value, err) continue } + ival := int(fval) switch fam.Name { @@ -119,10 +127,11 @@ func GetParserMetric(url string, itemName string) map[string]map[string]int { } } } + return stats } -func GetScenarioMetric(url string, itemName string) map[string]int { +func getScenarioMetric(url string, itemName string) map[string]int { stats := make(map[string]int) stats["instantiation"] = 0 @@ -131,31 +140,39 @@ func GetScenarioMetric(url string, itemName string) map[string]int { stats["pour"] = 0 stats["underflow"] = 0 - result := GetPrometheusMetric(url) + result := getPrometheusMetric(url) for idx, fam := range result { if !strings.HasPrefix(fam.Name, "cs_") { continue } + log.Tracef("round %d", idx) + for _, m := range fam.Metrics { metric, ok := m.(prom2json.Metric) if !ok { log.Debugf("failed to convert metric to prom2json.Metric") continue } + name, ok := metric.Labels["name"] + if !ok { log.Debugf("no name in Metric %v", metric.Labels) } + if name != itemName { continue } + value := m.(prom2json.Metric).Value + fval, err := strconv.ParseFloat(value, 32) if err != nil { log.Errorf("Unexpected int value %s : %s", value, err) continue } + ival := int(fval) switch fam.Name { @@ -174,31 +191,37 @@ func GetScenarioMetric(url string, itemName string) map[string]int { } } } + return stats } -func GetAppsecRuleMetric(url string, itemName string) map[string]int { +func getAppsecRuleMetric(url string, itemName string) map[string]int { stats := make(map[string]int) stats["inband_hits"] = 0 stats["outband_hits"] = 0 - results := GetPrometheusMetric(url) + results := getPrometheusMetric(url) for idx, fam := range results { if !strings.HasPrefix(fam.Name, "cs_") { continue } + log.Tracef("round %d", idx) + for _, m := range fam.Metrics { metric, ok := m.(prom2json.Metric) if !ok { log.Debugf("failed to convert metric to prom2json.Metric") continue } + name, ok := metric.Labels["rule_name"] + if !ok { log.Debugf("no rule_name in Metric %v", metric.Labels) } + if name != itemName { continue } @@ -209,11 +232,13 @@ func GetAppsecRuleMetric(url string, itemName string) map[string]int { } value := m.(prom2json.Metric).Value + fval, err := strconv.ParseFloat(value, 32) if err != nil { log.Errorf("Unexpected int value %s : %s", value, err) continue } + ival := int(fval) switch fam.Name { @@ -231,10 +256,11 @@ func GetAppsecRuleMetric(url string, itemName string) map[string]int { } } } + return stats } -func GetPrometheusMetric(url string) []*prom2json.Family { +func getPrometheusMetric(url string) []*prom2json.Family { mfChan := make(chan *dto.MetricFamily, 1024) // Start with the DefaultTransport for sane defaults. @@ -247,6 +273,7 @@ func GetPrometheusMetric(url string) []*prom2json.Family { go func() { defer trace.CatchPanic("crowdsec/GetPrometheusMetric") + err := prom2json.FetchMetricFamilies(url, mfChan, transport) if err != nil { log.Fatalf("failed to fetch prometheus metrics : %v", err) @@ -257,41 +284,8 @@ func GetPrometheusMetric(url string) []*prom2json.Family { for mf := range mfChan { result = append(result, prom2json.NewFamily(mf)) } + log.Debugf("Finished reading prometheus output, %d entries", len(result)) return result } - -type unit struct { - value int64 - symbol string -} - -var ranges = []unit{ - {value: 1e18, symbol: "E"}, - {value: 1e15, symbol: "P"}, - {value: 1e12, symbol: "T"}, - {value: 1e9, symbol: "G"}, - {value: 1e6, symbol: "M"}, - {value: 1e3, symbol: "k"}, - {value: 1, symbol: ""}, -} - -func formatNumber(num int) string { - goodUnit := unit{} - - for _, u := range ranges { - if int64(num) >= u.value { - goodUnit = u - break - } - } - - if goodUnit.value == 1 { - return fmt.Sprintf("%d%s", num, goodUnit.symbol) - } - - res := math.Round(float64(num)/float64(goodUnit.value)*100) / 100 - - return fmt.Sprintf("%.2f%s", res, goodUnit.symbol) -} diff --git a/cmd/crowdsec-cli/items.go b/cmd/crowdsec-cli/clihub/items.go similarity index 74% rename from cmd/crowdsec-cli/items.go rename to cmd/crowdsec-cli/clihub/items.go index a1d079747fa..f86fe65a2a1 100644 --- a/cmd/crowdsec-cli/items.go +++ b/cmd/crowdsec-cli/clihub/items.go @@ -1,4 +1,4 @@ -package main +package clihub import ( "encoding/csv" @@ -16,8 +16,13 @@ import ( ) // selectItems returns a slice of items of a given type, selected by name and sorted by case-insensitive name -func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly bool) ([]*cwhub.Item, error) { - itemNames := hub.GetItemNames(itemType) +func SelectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly bool) ([]*cwhub.Item, error) { + allItems := hub.GetItemsByType(itemType, true) + + itemNames := make([]string, len(allItems)) + for idx, item := range allItems { + itemNames[idx] = item.Name + } notExist := []string{} @@ -38,7 +43,7 @@ func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly b installedOnly = false } - items := make([]*cwhub.Item, 0, len(itemNames)) + wantedItems := make([]*cwhub.Item, 0, len(itemNames)) for _, itemName := range itemNames { item := hub.GetItem(itemType, itemName) @@ -46,16 +51,14 @@ func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly b continue } - items = append(items, item) + wantedItems = append(wantedItems, item) } - cwhub.SortItemSlice(items) - - return items, nil + return wantedItems, nil } -func listItems(out io.Writer, itemTypes []string, items map[string][]*cwhub.Item, omitIfEmpty bool) error { - switch csConfig.Cscli.Output { +func ListItems(out io.Writer, wantColor string, itemTypes []string, items map[string][]*cwhub.Item, omitIfEmpty bool, output string) error { + switch output { case "human": nothingToDisplay := true @@ -64,7 +67,7 @@ func listItems(out io.Writer, itemTypes []string, items map[string][]*cwhub.Item continue } - listHubItemTable(out, "\n"+strings.ToUpper(itemType), items[itemType]) + listHubItemTable(out, wantColor, "\n"+strings.ToUpper(itemType), items[itemType]) nothingToDisplay = false } @@ -103,7 +106,7 @@ func listItems(out io.Writer, itemTypes []string, items map[string][]*cwhub.Item x, err := json.MarshalIndent(hubStatus, "", " ") if err != nil { - return fmt.Errorf("failed to unmarshal: %w", err) + return fmt.Errorf("failed to parse: %w", err) } out.Write(x) @@ -116,7 +119,7 @@ func listItems(out io.Writer, itemTypes []string, items map[string][]*cwhub.Item } if err := csvwriter.Write(header); err != nil { - return fmt.Errorf("failed to write header: %s", err) + return fmt.Errorf("failed to write header: %w", err) } for _, itemType := range itemTypes { @@ -132,38 +135,36 @@ func listItems(out io.Writer, itemTypes []string, items map[string][]*cwhub.Item } if err := csvwriter.Write(row); err != nil { - return fmt.Errorf("failed to write raw output: %s", err) + return fmt.Errorf("failed to write raw output: %w", err) } } } csvwriter.Flush() - default: - return fmt.Errorf("unknown output format '%s'", csConfig.Cscli.Output) } return nil } -func InspectItem(item *cwhub.Item, showMetrics bool) error { - switch csConfig.Cscli.Output { +func InspectItem(item *cwhub.Item, wantMetrics bool, output string, prometheusURL string, wantColor string) error { + switch output { case "human", "raw": enc := yaml.NewEncoder(os.Stdout) enc.SetIndent(2) if err := enc.Encode(item); err != nil { - return fmt.Errorf("unable to encode item: %s", err) + return fmt.Errorf("unable to encode item: %w", err) } case "json": b, err := json.MarshalIndent(*item, "", " ") if err != nil { - return fmt.Errorf("unable to marshal item: %s", err) + return fmt.Errorf("unable to serialize item: %w", err) } fmt.Print(string(b)) } - if csConfig.Cscli.Output != "human" { + if output != "human" { return nil } @@ -173,10 +174,10 @@ func InspectItem(item *cwhub.Item, showMetrics bool) error { fmt.Println() } - if showMetrics { + if wantMetrics { fmt.Printf("\nCurrent metrics: \n") - if err := ShowMetrics(item); err != nil { + if err := showMetrics(prometheusURL, item, wantColor); err != nil { return err } } diff --git a/cmd/crowdsec-cli/clihub/utils_table.go b/cmd/crowdsec-cli/clihub/utils_table.go new file mode 100644 index 00000000000..98f14341b10 --- /dev/null +++ b/cmd/crowdsec-cli/clihub/utils_table.go @@ -0,0 +1,85 @@ +package clihub + +import ( + "fmt" + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func listHubItemTable(out io.Writer, wantColor string, title string, items []*cwhub.Item) { + t := cstable.NewLight(out, wantColor).Writer + t.AppendHeader(table.Row{"Name", fmt.Sprintf("%v Status", emoji.Package), "Version", "Local Path"}) + + for _, item := range items { + status := fmt.Sprintf("%v %s", item.State.Emoji(), item.State.Text()) + t.AppendRow(table.Row{item.Name, status, item.State.LocalVersion, item.State.LocalPath}) + } + + io.WriteString(out, title+"\n") + io.WriteString(out, t.Render()+"\n") +} + +func appsecMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { + t := cstable.NewLight(out, wantColor).Writer + t.AppendHeader(table.Row{"Inband Hits", "Outband Hits"}) + + t.AppendRow(table.Row{ + strconv.Itoa(metrics["inband_hits"]), + strconv.Itoa(metrics["outband_hits"]), + }) + + io.WriteString(out, fmt.Sprintf("\n - (AppSec Rule) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") +} + +func scenarioMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { + if metrics["instantiation"] == 0 { + return + } + + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Current Count", "Overflows", "Instantiated", "Poured", "Expired"}) + + t.AppendRow(table.Row{ + strconv.Itoa(metrics["curr_count"]), + strconv.Itoa(metrics["overflow"]), + strconv.Itoa(metrics["instantiation"]), + strconv.Itoa(metrics["pour"]), + strconv.Itoa(metrics["underflow"]), + }) + + io.WriteString(out, fmt.Sprintf("\n - (Scenario) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") +} + +func parserMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]map[string]int) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Parsers", "Hits", "Parsed", "Unparsed"}) + + // don't show table if no hits + showTable := false + + for source, stats := range metrics { + if stats["hits"] > 0 { + t.AppendRow(table.Row{ + source, + strconv.Itoa(stats["hits"]), + strconv.Itoa(stats["parsed"]), + strconv.Itoa(stats["unparsed"]), + }) + + showTable = true + } + } + + if showTable { + io.WriteString(out, fmt.Sprintf("\n - (Parser) %s:\n", itemName)) + io.WriteString(out, t.Render()+"\n") + } +} diff --git a/cmd/crowdsec-cli/clihubtest/clean.go b/cmd/crowdsec-cli/clihubtest/clean.go new file mode 100644 index 00000000000..e3b40b6bd57 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/clean.go @@ -0,0 +1,31 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newCleanCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clean", + Short: "clean [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/coverage.go b/cmd/crowdsec-cli/clihubtest/coverage.go new file mode 100644 index 00000000000..5a4f231caf5 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/coverage.go @@ -0,0 +1,166 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "math" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +// getCoverage returns the coverage and the percentage of tests that passed +func getCoverage(show bool, getCoverageFunc func() ([]hubtest.Coverage, error)) ([]hubtest.Coverage, int, error) { + if !show { + return nil, 0, nil + } + + coverage, err := getCoverageFunc() + if err != nil { + return nil, 0, fmt.Errorf("while getting coverage: %w", err) + } + + tested := 0 + + for _, test := range coverage { + if test.TestsCount > 0 { + tested++ + } + } + + // keep coverage 0 if there's no tests? + percent := 0 + if len(coverage) > 0 { + percent = int(math.Round((float64(tested) / float64(len(coverage)) * 100))) + } + + return coverage, percent, nil +} + +func (cli *cliHubTest) coverage(showScenarioCov bool, showParserCov bool, showAppsecCov bool, showOnlyPercent bool) error { + cfg := cli.cfg() + + // for this one we explicitly don't do for appsec + if err := HubTest.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %+v", err) + } + + var err error + + // if all are false (flag by default), show them + if !showParserCov && !showScenarioCov && !showAppsecCov { + showParserCov = true + showScenarioCov = true + showAppsecCov = true + } + + parserCoverage, parserCoveragePercent, err := getCoverage(showParserCov, HubTest.GetParsersCoverage) + if err != nil { + return err + } + + scenarioCoverage, scenarioCoveragePercent, err := getCoverage(showScenarioCov, HubTest.GetScenariosCoverage) + if err != nil { + return err + } + + appsecRuleCoverage, appsecRuleCoveragePercent, err := getCoverage(showAppsecCov, HubTest.GetAppsecCoverage) + if err != nil { + return err + } + + if showOnlyPercent { + switch { + case showParserCov: + fmt.Printf("parsers=%d%%", parserCoveragePercent) + case showScenarioCov: + fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) + case showAppsecCov: + fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) + } + + return nil + } + + switch cfg.Cscli.Output { + case "human": + if showParserCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Parser", "Status", "Number of tests"}, parserCoverage) + } + + if showScenarioCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Scenario", "Status", "Number of tests"}, parserCoverage) + } + + if showAppsecCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Appsec Rule", "Status", "Number of tests"}, parserCoverage) + } + + fmt.Println() + + if showParserCov { + fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent) + } + + if showScenarioCov { + fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent) + } + + if showAppsecCov { + fmt.Printf("APPSEC RULES : %d%% of coverage\n", appsecRuleCoveragePercent) + } + case "json": + dump, err := json.MarshalIndent(parserCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + + dump, err = json.MarshalIndent(scenarioCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + + dump, err = json.MarshalIndent(appsecRuleCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + default: + return errors.New("only human/json output modes are supported") + } + + return nil +} + +func (cli *cliHubTest) newCoverageCmd() *cobra.Command { + var ( + showParserCov bool + showScenarioCov bool + showOnlyPercent bool + showAppsecCov bool + ) + + cmd := &cobra.Command{ + Use: "coverage", + Short: "coverage", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.coverage(showScenarioCov, showParserCov, showAppsecCov, showOnlyPercent) + }, + } + + cmd.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage") + cmd.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage") + cmd.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage") + cmd.PersistentFlags().BoolVar(&showAppsecCov, "appsec", false, "Show only appsec coverage") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/create.go b/cmd/crowdsec-cli/clihubtest/create.go new file mode 100644 index 00000000000..3822bed8903 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/create.go @@ -0,0 +1,158 @@ +package clihubtest + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "text/template" + + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) newCreateCmd() *cobra.Command { + var ( + ignoreParsers bool + labels map[string]string + logType string + ) + + parsers := []string{} + postoverflows := []string{} + scenarios := []string{} + + cmd := &cobra.Command{ + Use: "create", + Short: "create [test_name]", + Example: `cscli hubtest create my-awesome-test --type syslog +cscli hubtest create my-nginx-custom-test --type nginx +cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + testName := args[0] + testPath := filepath.Join(hubPtr.HubTestPath, testName) + if _, err := os.Stat(testPath); os.IsExist(err) { + return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath) + } + + if isAppsecTest { + logType = "appsec" + } + + if logType == "" { + return errors.New("please provide a type (--type) for the test") + } + + if err := os.MkdirAll(testPath, os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %+v", testPath, err) + } + + configFilePath := filepath.Join(testPath, "config.yaml") + + configFileData := &hubtest.HubTestItemConfig{} + if logType == "appsec" { + // create empty nuclei template file + nucleiFileName := testName + ".yaml" + nucleiFilePath := filepath.Join(testPath, nucleiFileName) + + nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0o755) + if err != nil { + return err + } + + ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) + if ntpl == nil { + return errors.New("unable to parse nuclei template") + } + ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) + nucleiFile.Close() + configFileData.AppsecRules = []string{"./appsec-rules//your_rule_here.yaml"} + configFileData.NucleiTemplate = nucleiFileName + fmt.Println() + fmt.Printf(" Test name : %s\n", testName) + fmt.Printf(" Test path : %s\n", testPath) + fmt.Printf(" Config File : %s\n", configFilePath) + fmt.Printf(" Nuclei Template : %s\n", nucleiFilePath) + } else { + // create empty log file + logFileName := testName + ".log" + logFilePath := filepath.Join(testPath, logFileName) + logFile, err := os.Create(logFilePath) + if err != nil { + return err + } + logFile.Close() + + // create empty parser assertion file + parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName) + parserAssertFile, err := os.Create(parserAssertFilePath) + if err != nil { + return err + } + parserAssertFile.Close() + // create empty scenario assertion file + scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName) + scenarioAssertFile, err := os.Create(scenarioAssertFilePath) + if err != nil { + return err + } + scenarioAssertFile.Close() + + parsers = append(parsers, "crowdsecurity/syslog-logs") + parsers = append(parsers, "crowdsecurity/dateparse-enrich") + + if len(scenarios) == 0 { + scenarios = append(scenarios, "") + } + + if len(postoverflows) == 0 { + postoverflows = append(postoverflows, "") + } + configFileData.Parsers = parsers + configFileData.Scenarios = scenarios + configFileData.PostOverflows = postoverflows + configFileData.LogFile = logFileName + configFileData.LogType = logType + configFileData.IgnoreParsers = ignoreParsers + configFileData.Labels = labels + fmt.Println() + fmt.Printf(" Test name : %s\n", testName) + fmt.Printf(" Test path : %s\n", testPath) + fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) + fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) + fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) + fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) + } + + fd, err := os.Create(configFilePath) + if err != nil { + return fmt.Errorf("open: %w", err) + } + data, err := yaml.Marshal(configFileData) + if err != nil { + return fmt.Errorf("serialize: %w", err) + } + _, err = fd.Write(data) + if err != nil { + return fmt.Errorf("write: %w", err) + } + if err := fd.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") + cmd.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") + cmd.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") + cmd.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") + cmd.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/eval.go b/cmd/crowdsec-cli/clihubtest/eval.go new file mode 100644 index 00000000000..83e9eae9c15 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/eval.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newEvalCmd() *cobra.Command { + var evalExpression string + + cmd := &cobra.Command{ + Use: "eval", + Short: "eval [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) + } + + output, err := test.ParserAssert.EvalExpression(evalExpression) + if err != nil { + return err + } + + fmt.Print(output) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/explain.go b/cmd/crowdsec-cli/clihubtest/explain.go new file mode 100644 index 00000000000..dbe10fa7ec0 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/explain.go @@ -0,0 +1,76 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/dumps" +) + +func (cli *cliHubTest) explain(testName string, details bool, skipOk bool) error { + test, err := HubTest.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { + return fmt.Errorf("unable to load parser result after run: %w", err) + } + } + + err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { + return fmt.Errorf("unable to load scenario result after run: %w", err) + } + } + + opts := dumps.DumpOpts{ + Details: details, + SkipOk: skipOk, + } + + dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) + + return nil +} + +func (cli *cliHubTest) newExplainCmd() *cobra.Command { + var ( + details bool + skipOk bool + ) + + cmd := &cobra.Command{ + Use: "explain", + Short: "explain [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + if err := cli.explain(testName, details, skipOk); err != nil { + return err + } + } + + return nil + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&skipOk, "failures", false, "Only show failed lines") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/hubtest.go b/cmd/crowdsec-cli/clihubtest/hubtest.go new file mode 100644 index 00000000000..3420e21e1e2 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/hubtest.go @@ -0,0 +1,81 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +type configGetter func() *csconfig.Config + +var ( + HubTest hubtest.HubTest + HubAppsecTests hubtest.HubTest + hubPtr *hubtest.HubTest + isAppsecTest bool +) + +type cliHubTest struct { + cfg configGetter +} + +func New(cfg configGetter) *cliHubTest { + return &cliHubTest{ + cfg: cfg, + } +} + +func (cli *cliHubTest) NewCommand() *cobra.Command { + var ( + hubPath string + crowdsecPath string + cscliPath string + ) + + cmd := &cobra.Command{ + Use: "hubtest", + Short: "Run functional tests on hub configurations", + Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + var err error + HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, false) + if err != nil { + return fmt.Errorf("unable to load hubtest: %+v", err) + } + + HubAppsecTests, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, true) + if err != nil { + return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) + } + + // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests + hubPtr = &HubTest + if isAppsecTest { + hubPtr = &HubAppsecTests + } + + return nil + }, + } + + cmd.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") + cmd.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") + cmd.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") + cmd.PersistentFlags().BoolVar(&isAppsecTest, "appsec", false, "Command relates to appsec tests") + + cmd.AddCommand(cli.newCreateCmd()) + cmd.AddCommand(cli.newRunCmd()) + cmd.AddCommand(cli.newCleanCmd()) + cmd.AddCommand(cli.newInfoCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newCoverageCmd()) + cmd.AddCommand(cli.newEvalCmd()) + cmd.AddCommand(cli.newExplainCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/info.go b/cmd/crowdsec-cli/clihubtest/info.go new file mode 100644 index 00000000000..a5d760eea01 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/info.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) newInfoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "info", + Short: "info [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + fmt.Println() + fmt.Printf(" Test name : %s\n", test.Name) + fmt.Printf(" Test path : %s\n", test.Path) + if isAppsecTest { + fmt.Printf(" Nuclei Template : %s\n", test.Config.NucleiTemplate) + fmt.Printf(" Appsec Rules : %s\n", strings.Join(test.Config.AppsecRules, ", ")) + } else { + fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) + fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) + fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) + } + fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/list.go b/cmd/crowdsec-cli/clihubtest/list.go new file mode 100644 index 00000000000..3e76824a18e --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/list.go @@ -0,0 +1,42 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + if err := hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %w", err) + } + + switch cfg.Cscli.Output { + case "human": + hubTestListTable(color.Output, cfg.Cscli.Color, hubPtr.Tests) + case "json": + j, err := json.MarshalIndent(hubPtr.Tests, " ", " ") + if err != nil { + return err + } + fmt.Println(string(j)) + default: + return errors.New("only human/json output modes are supported") + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/run.go b/cmd/crowdsec-cli/clihubtest/run.go new file mode 100644 index 00000000000..31cceb81884 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/run.go @@ -0,0 +1,213 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) run(runAll bool, nucleiTargetHost string, appSecHost string, args []string) error { + cfg := cli.cfg() + + if !runAll && len(args) == 0 { + return errors.New("please provide test to run or --all flag") + } + + hubPtr.NucleiTargetHost = nucleiTargetHost + hubPtr.AppSecHost = appSecHost + + if runAll { + if err := hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %+v", err) + } + } else { + for _, testName := range args { + _, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + } + } + + // set timezone to avoid DST issues + os.Setenv("TZ", "UTC") + + for _, test := range hubPtr.Tests { + if cfg.Cscli.Output == "human" { + log.Infof("Running test '%s'", test.Name) + } + + err := test.Run() + if err != nil { + log.Errorf("running test '%s' failed: %+v", test.Name, err) + } + } + + return nil +} + +func printParserFailures(test *hubtest.HubTestItem) { + if len(test.ParserAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) + + for _, fail := range test.ParserAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func printScenarioFailures(test *hubtest.HubTestItem) { + if len(test.ScenarioAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) + + for _, fail := range test.ScenarioAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func (cli *cliHubTest) newRunCmd() *cobra.Command { + var ( + noClean bool + runAll bool + forceClean bool + nucleiTargetHost string + appSecHost string + ) + + cmd := &cobra.Command{ + Use: "run", + Short: "run [test_name]", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.run(runAll, nucleiTargetHost, appSecHost, args) + }, + PersistentPostRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + success := true + testResult := make(map[string]bool) + for _, test := range hubPtr.Tests { + if test.AutoGen && !isAppsecTest { + if test.ParserAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) + fmt.Println() + fmt.Println(test.ParserAssert.AutoGenAssertData) + } + if test.ScenarioAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) + fmt.Println() + fmt.Println(test.ScenarioAssert.AutoGenAssertData) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return fmt.Errorf("please fill your assert file(s) for test '%s', exiting", test.Name) + } + testResult[test.Name] = test.Success + if test.Success { + if cfg.Cscli.Output == "human" { + log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } else { + success = false + cleanTestEnv := false + if cfg.Cscli.Output == "human" { + printParserFailures(test) + printScenarioFailures(test) + if !forceClean && !noClean { + prompt := &survey.Confirm{ + Message: fmt.Sprintf("\nDo you want to remove runtime folder for test '%s'? (default: Yes)", test.Name), + Default: true, + } + if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { + return fmt.Errorf("unable to ask to remove runtime folder: %w", err) + } + } + } + + if cleanTestEnv || forceClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } + } + + switch cfg.Cscli.Output { + case "human": + hubTestResultTable(color.Output, cfg.Cscli.Color, testResult) + case "json": + jsonResult := make(map[string][]string, 0) + jsonResult["success"] = make([]string, 0) + jsonResult["fail"] = make([]string, 0) + for testName, success := range testResult { + if success { + jsonResult["success"] = append(jsonResult["success"], testName) + } else { + jsonResult["fail"] = append(jsonResult["fail"], testName) + } + } + jsonStr, err := json.Marshal(jsonResult) + if err != nil { + return fmt.Errorf("unable to json test result: %w", err) + } + fmt.Println(string(jsonStr)) + default: + return errors.New("only human/json output modes are supported") + } + + if !success { + return errors.New("some tests failed") + } + + return nil + }, + } + + cmd.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") + cmd.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") + cmd.Flags().StringVar(&nucleiTargetHost, "target", hubtest.DefaultNucleiTarget, "Target for AppSec Test") + cmd.Flags().StringVar(&appSecHost, "host", hubtest.DefaultAppsecHost, "Address to expose AppSec for hubtest") + cmd.Flags().BoolVar(&runAll, "all", false, "Run all tests") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/table.go b/cmd/crowdsec-cli/clihubtest/table.go new file mode 100644 index 00000000000..2a105a1f5c1 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/table.go @@ -0,0 +1,64 @@ +package clihubtest + +import ( + "fmt" + "io" + + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func hubTestResultTable(out io.Writer, wantColor string, testResult map[string]bool) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders("Test", "Result") + t.SetHeaderAlignment(text.AlignLeft) + t.SetAlignment(text.AlignLeft) + + for testName, success := range testResult { + status := emoji.CheckMarkButton + if !success { + status = emoji.CrossMark + } + + t.AddRow(testName, status) + } + + t.Render() +} + +func hubTestListTable(out io.Writer, wantColor string, tests []*hubtest.HubTestItem) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders("Name", "Path") + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft) + + for _, test := range tests { + t.AddRow(test.Name, test.Path) + } + + t.Render() +} + +func hubTestCoverageTable(out io.Writer, wantColor string, headers []string, coverage []hubtest.Coverage) { + t := cstable.NewLight(out, wantColor) + t.SetHeaders(headers...) + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) + + parserTested := 0 + + for _, test := range coverage { + status := emoji.RedCircle + if test.TestsCount > 0 { + status = emoji.GreenCircle + parserTested++ + } + + t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) + } + + t.Render() +} diff --git a/cmd/crowdsec-cli/hubappsec.go b/cmd/crowdsec-cli/cliitem/appsec.go similarity index 87% rename from cmd/crowdsec-cli/hubappsec.go rename to cmd/crowdsec-cli/cliitem/appsec.go index ff41ad5f9ad..44afa2133bd 100644 --- a/cmd/crowdsec-cli/hubappsec.go +++ b/cmd/crowdsec-cli/cliitem/appsec.go @@ -1,4 +1,4 @@ -package main +package cliitem import ( "fmt" @@ -13,8 +13,9 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIAppsecConfig() *cliItem { +func NewAppsecConfig(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.APPSEC_CONFIGS, singular: "appsec-config", oneOrMore: "appsec-config(s)", @@ -46,10 +47,10 @@ cscli appsec-configs list crowdsecurity/vpatch`, } } -func NewCLIAppsecRule() *cliItem { +func NewAppsecRule(cfg configGetter) *cliItem { inspectDetail := func(item *cwhub.Item) error { // Only show the converted rules in human mode - if csConfig.Cscli.Output != "human" { + if cfg().Cscli.Output != "human" { return nil } @@ -57,11 +58,11 @@ func NewCLIAppsecRule() *cliItem { yamlContent, err := os.ReadFile(item.State.LocalPath) if err != nil { - return fmt.Errorf("unable to read file %s : %s", item.State.LocalPath, err) + return fmt.Errorf("unable to read file %s: %w", item.State.LocalPath, err) } if err := yaml.Unmarshal(yamlContent, &appsecRule); err != nil { - return fmt.Errorf("unable to unmarshal yaml file %s : %s", item.State.LocalPath, err) + return fmt.Errorf("unable to parse yaml file %s: %w", item.State.LocalPath, err) } for _, ruleType := range appsec_rule.SupportedTypes() { @@ -70,7 +71,7 @@ func NewCLIAppsecRule() *cliItem { for _, rule := range appsecRule.Rules { convertedRule, _, err := rule.Convert(ruleType, appsecRule.Name) if err != nil { - return fmt.Errorf("unable to convert rule %s : %s", rule.Name, err) + return fmt.Errorf("unable to convert rule %s: %w", rule.Name, err) } fmt.Println(convertedRule) @@ -88,6 +89,7 @@ func NewCLIAppsecRule() *cliItem { } return &cliItem{ + cfg: cfg, name: "appsec-rules", singular: "appsec-rule", oneOrMore: "appsec-rule(s)", diff --git a/cmd/crowdsec-cli/hubcollection.go b/cmd/crowdsec-cli/cliitem/collection.go similarity index 93% rename from cmd/crowdsec-cli/hubcollection.go rename to cmd/crowdsec-cli/cliitem/collection.go index dee9a0b9e66..ea91c1e537a 100644 --- a/cmd/crowdsec-cli/hubcollection.go +++ b/cmd/crowdsec-cli/cliitem/collection.go @@ -1,11 +1,12 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLICollection() *cliItem { +func NewCollection(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.COLLECTIONS, singular: "collection", oneOrMore: "collection(s)", diff --git a/cmd/crowdsec-cli/hubcontext.go b/cmd/crowdsec-cli/cliitem/context.go similarity index 93% rename from cmd/crowdsec-cli/hubcontext.go rename to cmd/crowdsec-cli/cliitem/context.go index 630dbb2f7b6..7d110b8203d 100644 --- a/cmd/crowdsec-cli/hubcontext.go +++ b/cmd/crowdsec-cli/cliitem/context.go @@ -1,11 +1,12 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIContext() *cliItem { +func NewContext(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.CONTEXTS, singular: "context", oneOrMore: "context(s)", diff --git a/cmd/crowdsec-cli/hubscenario.go b/cmd/crowdsec-cli/cliitem/hubscenario.go similarity index 93% rename from cmd/crowdsec-cli/hubscenario.go rename to cmd/crowdsec-cli/cliitem/hubscenario.go index 1de2182bfc5..a5e854b3c82 100644 --- a/cmd/crowdsec-cli/hubscenario.go +++ b/cmd/crowdsec-cli/cliitem/hubscenario.go @@ -1,11 +1,12 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIScenario() *cliItem { +func NewScenario(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.SCENARIOS, singular: "scenario", oneOrMore: "scenario(s)", diff --git a/cmd/crowdsec-cli/cliitem/item.go b/cmd/crowdsec-cli/cliitem/item.go new file mode 100644 index 00000000000..28828eb9c95 --- /dev/null +++ b/cmd/crowdsec-cli/cliitem/item.go @@ -0,0 +1,550 @@ +package cliitem + +import ( + "cmp" + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/fatih/color" + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +type cliHelp struct { + // Example is required, the others have a default value + // generated from the item type + use string + short string + long string + example string +} + +type configGetter func() *csconfig.Config + +type cliItem struct { + cfg configGetter + name string // plural, as used in the hub index + singular string + oneOrMore string // parenthetical pluralizaion: "parser(s)" + help cliHelp + installHelp cliHelp + removeHelp cliHelp + upgradeHelp cliHelp + inspectHelp cliHelp + inspectDetail func(item *cwhub.Item) error + listHelp cliHelp +} + +func (cli cliItem) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: cmp.Or(cli.help.use, cli.name+" [item]..."), + Short: cmp.Or(cli.help.short, "Manage hub "+cli.name), + Long: cli.help.long, + Example: cli.help.example, + Args: cobra.MinimumNArgs(1), + Aliases: []string{cli.singular}, + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newInstallCmd()) + cmd.AddCommand(cli.newRemoveCmd()) + cmd.AddCommand(cli.newUpgradeCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newListCmd()) + + return cmd +} + +func (cli cliItem) install(ctx context.Context, args []string, downloadOnly bool, force bool, ignoreError bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + for _, name := range args { + item := hub.GetItem(cli.name, name) + if item == nil { + msg := suggestNearestMessage(hub, cli.name, name) + if !ignoreError { + return errors.New(msg) + } + + log.Error(msg) + + continue + } + + if err := item.Install(ctx, force, downloadOnly); err != nil { + if !ignoreError { + return fmt.Errorf("error while installing '%s': %w", item.Name, err) + } + + log.Errorf("Error while installing '%s': %s", item.Name, err) + } + } + + log.Info(reload.Message) + + return nil +} + +func (cli cliItem) newInstallCmd() *cobra.Command { + var ( + downloadOnly bool + force bool + ignoreError bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.installHelp.use, "install [item]..."), + Short: cmp.Or(cli.installHelp.short, "Install given "+cli.oneOrMore), + Long: cmp.Or(cli.installHelp.long, fmt.Sprintf("Fetch and install one or more %s from the hub", cli.name)), + Example: cli.installHelp.example, + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compAllItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.install(cmd.Context(), args, downloadOnly, force, ignoreError) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&downloadOnly, "download-only", "d", false, "Only download packages, don't enable") + flags.BoolVar(&force, "force", false, "Force install: overwrite tainted and outdated files") + flags.BoolVar(&ignoreError, "ignore", false, "Ignore errors when installing multiple "+cli.name) + + return cmd +} + +// return the names of the installed parents of an item, used to check if we can remove it +func istalledParentNames(item *cwhub.Item) []string { + ret := make([]string, 0) + + for _, parent := range item.Ancestors() { + if parent.State.Installed { + ret = append(ret, parent.Name) + } + } + + return ret +} + +func (cli cliItem) remove(args []string, purge bool, force bool, all bool) error { + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + if all { + itemGetter := hub.GetInstalledByType + if purge { + itemGetter = hub.GetItemsByType + } + + removed := 0 + + for _, item := range itemGetter(cli.name, true) { + didRemove, err := item.Remove(purge, force) + if err != nil { + return err + } + + if didRemove { + log.Infof("Removed %s", item.Name) + + removed++ + } + } + + log.Infof("Removed %d %s", removed, cli.name) + + if removed > 0 { + log.Info(reload.Message) + } + + return nil + } + + if len(args) == 0 { + return fmt.Errorf("specify at least one %s to remove or '--all'", cli.singular) + } + + removed := 0 + + for _, itemName := range args { + item := hub.GetItem(cli.name, itemName) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) + } + + parents := istalledParentNames(item) + + if !force && len(parents) > 0 { + log.Warningf("%s belongs to collections: %s", item.Name, parents) + log.Warningf("Run 'sudo cscli %s remove %s --force' if you want to force remove this %s", item.Type, item.Name, cli.singular) + + continue + } + + didRemove, err := item.Remove(purge, force) + if err != nil { + return err + } + + if didRemove { + log.Infof("Removed %s", item.Name) + + removed++ + } + } + + log.Infof("Removed %d %s", removed, cli.name) + + if removed > 0 { + log.Info(reload.Message) + } + + return nil +} + +func (cli cliItem) newRemoveCmd() *cobra.Command { + var ( + purge bool + force bool + all bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.removeHelp.use, "remove [item]..."), + Short: cmp.Or(cli.removeHelp.short, "Remove given "+cli.oneOrMore), + Long: cmp.Or(cli.removeHelp.long, "Remove one or more "+cli.name), + Example: cli.removeHelp.example, + Aliases: []string{"delete"}, + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(_ *cobra.Command, args []string) error { + return cli.remove(args, purge, force, all) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&purge, "purge", false, "Delete source file too") + flags.BoolVar(&force, "force", false, "Force remove: remove tainted and outdated files") + flags.BoolVar(&all, "all", false, "Remove all the "+cli.name) + + return cmd +} + +func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + if all { + updated := 0 + + for _, item := range hub.GetInstalledByType(cli.name, true) { + didUpdate, err := item.Upgrade(ctx, force) + if err != nil { + return err + } + + if didUpdate { + updated++ + } + } + + log.Infof("Updated %d %s", updated, cli.name) + + if updated > 0 { + log.Info(reload.Message) + } + + return nil + } + + if len(args) == 0 { + return fmt.Errorf("specify at least one %s to upgrade or '--all'", cli.singular) + } + + updated := 0 + + for _, itemName := range args { + item := hub.GetItem(cli.name, itemName) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) + } + + didUpdate, err := item.Upgrade(ctx, force) + if err != nil { + return err + } + + if didUpdate { + log.Infof("Updated %s", item.Name) + + updated++ + } + } + + if updated > 0 { + log.Info(reload.Message) + } + + return nil +} + +func (cli cliItem) newUpgradeCmd() *cobra.Command { + var ( + all bool + force bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.upgradeHelp.use, "upgrade [item]..."), + Short: cmp.Or(cli.upgradeHelp.short, "Upgrade given "+cli.oneOrMore), + Long: cmp.Or(cli.upgradeHelp.long, fmt.Sprintf("Fetch and upgrade one or more %s from the hub", cli.name)), + Example: cli.upgradeHelp.example, + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.upgrade(cmd.Context(), args, force, all) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&all, "all", "a", false, "Upgrade all the "+cli.name) + flags.BoolVar(&force, "force", false, "Force upgrade: overwrite tainted and outdated files") + + return cmd +} + +func (cli cliItem) inspect(ctx context.Context, args []string, url string, diff bool, rev bool, noMetrics bool) error { + cfg := cli.cfg() + + if rev && !diff { + return errors.New("--rev can only be used with --diff") + } + + if url != "" { + cfg.Cscli.PrometheusUrl = url + } + + remote := (*cwhub.RemoteHubCfg)(nil) + + if diff { + remote = require.RemoteHub(ctx, cfg) + } + + hub, err := require.Hub(cfg, remote, log.StandardLogger()) + if err != nil { + return err + } + + for _, name := range args { + item := hub.GetItem(cli.name, name) + if item == nil { + return fmt.Errorf("can't find '%s' in %s", name, cli.name) + } + + if diff { + fmt.Println(cli.whyTainted(ctx, hub, item, rev)) + + continue + } + + if err = clihub.InspectItem(item, !noMetrics, cfg.Cscli.Output, cfg.Cscli.PrometheusUrl, cfg.Cscli.Color); err != nil { + return err + } + + if cli.inspectDetail != nil { + if err = cli.inspectDetail(item); err != nil { + return err + } + } + } + + return nil +} + +func (cli cliItem) newInspectCmd() *cobra.Command { + var ( + url string + diff bool + rev bool + noMetrics bool + ) + + cmd := &cobra.Command{ + Use: cmp.Or(cli.inspectHelp.use, "inspect [item]..."), + Short: cmp.Or(cli.inspectHelp.short, "Inspect given "+cli.oneOrMore), + Long: cmp.Or(cli.inspectHelp.long, "Inspect the state of one or more "+cli.name), + Example: cli.inspectHelp.example, + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return compInstalledItems(cli.name, args, toComplete, cli.cfg) + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.inspect(cmd.Context(), args, url, diff, rev, noMetrics) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Prometheus url") + flags.BoolVar(&diff, "diff", false, "Show diff with latest version (for tainted items)") + flags.BoolVar(&rev, "rev", false, "Reverse diff output") + flags.BoolVar(&noMetrics, "no-metrics", false, "Don't show metrics (when cscli.output=human)") + + return cmd +} + +func (cli cliItem) list(args []string, all bool) error { + cfg := cli.cfg() + + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + items := make(map[string][]*cwhub.Item) + + items[cli.name], err = clihub.SelectItems(hub, cli.name, args, !all) + if err != nil { + return err + } + + return clihub.ListItems(color.Output, cfg.Cscli.Color, []string{cli.name}, items, false, cfg.Cscli.Output) +} + +func (cli cliItem) newListCmd() *cobra.Command { + var all bool + + cmd := &cobra.Command{ + Use: cmp.Or(cli.listHelp.use, "list [item... | -a]"), + Short: cmp.Or(cli.listHelp.short, "List "+cli.oneOrMore), + Long: cmp.Or(cli.listHelp.long, "List of installed/available/specified "+cli.name), + Example: cli.listHelp.example, + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.list(args, all) + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&all, "all", "a", false, "List disabled items as well") + + return cmd +} + +// return the diff between the installed version and the latest version +func (cli cliItem) itemDiff(ctx context.Context, item *cwhub.Item, reverse bool) (string, error) { + if !item.State.Installed { + return "", fmt.Errorf("'%s' is not installed", item.FQName()) + } + + dest, err := os.CreateTemp("", "cscli-diff-*") + if err != nil { + return "", fmt.Errorf("while creating temporary file: %w", err) + } + defer os.Remove(dest.Name()) + + _, remoteURL, err := item.FetchContentTo(ctx, dest.Name()) + if err != nil { + return "", err + } + + latestContent, err := os.ReadFile(dest.Name()) + if err != nil { + return "", fmt.Errorf("while reading %s: %w", dest.Name(), err) + } + + localContent, err := os.ReadFile(item.State.LocalPath) + if err != nil { + return "", fmt.Errorf("while reading %s: %w", item.State.LocalPath, err) + } + + file1 := item.State.LocalPath + file2 := remoteURL + content1 := string(localContent) + content2 := string(latestContent) + + if reverse { + file1, file2 = file2, file1 + content1, content2 = content2, content1 + } + + edits := myers.ComputeEdits(span.URIFromPath(file1), content1, content2) + diff := gotextdiff.ToUnified(file1, file2, content1, edits) + + return fmt.Sprintf("%s", diff), nil +} + +func (cli cliItem) whyTainted(ctx context.Context, hub *cwhub.Hub, item *cwhub.Item, reverse bool) string { + if !item.State.Installed { + return fmt.Sprintf("# %s is not installed", item.FQName()) + } + + if !item.State.Tainted { + return fmt.Sprintf("# %s is not tainted", item.FQName()) + } + + if len(item.State.TaintedBy) == 0 { + return fmt.Sprintf("# %s is tainted but we don't know why. please report this as a bug", item.FQName()) + } + + ret := []string{ + fmt.Sprintf("# Let's see why %s is tainted.", item.FQName()), + } + + for _, fqsub := range item.State.TaintedBy { + ret = append(ret, fmt.Sprintf("\n-> %s\n", fqsub)) + + sub, err := hub.GetItemFQ(fqsub) + if err != nil { + ret = append(ret, err.Error()) + } + + diff, err := cli.itemDiff(ctx, sub, reverse) + if err != nil { + ret = append(ret, err.Error()) + } + + if diff != "" { + ret = append(ret, diff) + } else if len(sub.State.TaintedBy) > 0 { + taintList := strings.Join(sub.State.TaintedBy, ", ") + if sub.FQName() == taintList { + // hack: avoid message "item is tainted by itself" + continue + } + + ret = append(ret, fmt.Sprintf("# %s is tainted by %s", sub.FQName(), taintList)) + } + } + + return strings.Join(ret, "\n") +} diff --git a/cmd/crowdsec-cli/hubparser.go b/cmd/crowdsec-cli/cliitem/parser.go similarity index 93% rename from cmd/crowdsec-cli/hubparser.go rename to cmd/crowdsec-cli/cliitem/parser.go index 0b224c8a7f6..bc1d96bdaf0 100644 --- a/cmd/crowdsec-cli/hubparser.go +++ b/cmd/crowdsec-cli/cliitem/parser.go @@ -1,11 +1,12 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIParser() *cliItem { +func NewParser(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.PARSERS, singular: "parser", oneOrMore: "parser(s)", diff --git a/cmd/crowdsec-cli/hubpostoverflow.go b/cmd/crowdsec-cli/cliitem/postoverflow.go similarity index 93% rename from cmd/crowdsec-cli/hubpostoverflow.go rename to cmd/crowdsec-cli/cliitem/postoverflow.go index 908ccbea0fd..ea53aef327d 100644 --- a/cmd/crowdsec-cli/hubpostoverflow.go +++ b/cmd/crowdsec-cli/cliitem/postoverflow.go @@ -1,11 +1,12 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIPostOverflow() *cliItem { +func NewPostOverflow(cfg configGetter) *cliItem { return &cliItem{ + cfg: cfg, name: cwhub.POSTOVERFLOWS, singular: "postoverflow", oneOrMore: "postoverflow(s)", diff --git a/cmd/crowdsec-cli/item_suggest.go b/cmd/crowdsec-cli/cliitem/suggest.go similarity index 68% rename from cmd/crowdsec-cli/item_suggest.go rename to cmd/crowdsec-cli/cliitem/suggest.go index d3beee72100..5b080722af9 100644 --- a/cmd/crowdsec-cli/item_suggest.go +++ b/cmd/crowdsec-cli/cliitem/suggest.go @@ -1,4 +1,4 @@ -package main +package cliitem import ( "fmt" @@ -19,7 +19,7 @@ func suggestNearestMessage(hub *cwhub.Hub, itemType string, itemName string) str score := 100 nearest := "" - for _, item := range hub.GetItemMap(itemType) { + for _, item := range hub.GetItemsByType(itemType, false) { d := levenshtein.Distance(itemName, item.Name, nil) if d < score { score = d @@ -36,15 +36,15 @@ func suggestNearestMessage(hub *cwhub.Hub, itemType string, itemName string) str return msg } -func compAllItems(itemType string, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - hub, err := require.Hub(csConfig, nil, nil) +func compAllItems(itemType string, args []string, toComplete string, cfg configGetter) ([]string, cobra.ShellCompDirective) { + hub, err := require.Hub(cfg(), nil, nil) if err != nil { return nil, cobra.ShellCompDirectiveDefault } comp := make([]string, 0) - for _, item := range hub.GetItemMap(itemType) { + for _, item := range hub.GetItemsByType(itemType, false) { if !slices.Contains(args, item.Name) && strings.Contains(item.Name, toComplete) { comp = append(comp, item.Name) } @@ -55,28 +55,20 @@ func compAllItems(itemType string, args []string, toComplete string) ([]string, return comp, cobra.ShellCompDirectiveNoFileComp } -func compInstalledItems(itemType string, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - hub, err := require.Hub(csConfig, nil, nil) +func compInstalledItems(itemType string, args []string, toComplete string, cfg configGetter) ([]string, cobra.ShellCompDirective) { + hub, err := require.Hub(cfg(), nil, nil) if err != nil { return nil, cobra.ShellCompDirectiveDefault } - items, err := hub.GetInstalledItemNames(itemType) - if err != nil { - cobra.CompDebugln(fmt.Sprintf("list installed %s err: %s", itemType, err), true) - return nil, cobra.ShellCompDirectiveDefault - } + items := hub.GetInstalledByType(itemType, true) comp := make([]string, 0) - if toComplete != "" { - for _, item := range items { - if strings.Contains(item, toComplete) { - comp = append(comp, item) - } + for _, item := range items { + if strings.Contains(item.Name, toComplete) { + comp = append(comp, item.Name) } - } else { - comp = items } cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) diff --git a/cmd/crowdsec-cli/lapi.go b/cmd/crowdsec-cli/clilapi/lapi.go similarity index 57% rename from cmd/crowdsec-cli/lapi.go rename to cmd/crowdsec-cli/clilapi/lapi.go index ce59ac370cd..bb721eefe03 100644 --- a/cmd/crowdsec-cli/lapi.go +++ b/cmd/crowdsec-cli/clilapi/lapi.go @@ -1,22 +1,24 @@ -package main +package clilapi import ( "context" "errors" "fmt" + "io" "net/url" "os" + "slices" "sort" "strings" + "github.com/fatih/color" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - "slices" - - "github.com/crowdsecurity/go-cs-lib/version" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/apiclient" @@ -29,100 +31,96 @@ import ( const LAPIURLPrefix = "v1" -func runLapiStatus(cmd *cobra.Command, args []string) error { - password := strfmt.Password(csConfig.API.Client.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL) - login := csConfig.API.Client.Credentials.Login - if err != nil { - return fmt.Errorf("parsing api url: %w", err) - } +type configGetter = func() *csconfig.Config - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - return err +type cliLapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliLapi { + return &cliLapi{ + cfg: cfg, } +} - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) +// queryLAPIStatus checks if the Local API is reachable, and if the credentials are correct. +func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, error) { + apiURL, err := url.Parse(credURL) if err != nil { - return fmt.Errorf("failed to get scenarios: %w", err) + return false, err } - Client, err = apiclient.NewDefaultClient(apiurl, + client, err := apiclient.NewDefaultClient(apiURL, LAPIURLPrefix, - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil) if err != nil { - return fmt.Errorf("init default client: %w", err) + return false, err } + + pw := strfmt.Password(password) + + itemsForAPI := hub.GetInstalledListForAPI() + t := models.WatcherAuthRequest{ MachineID: &login, - Password: &password, - Scenarios: scenarios, + Password: &pw, + Scenarios: itemsForAPI, } - log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", login, apiurl) - - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + _, _, err = client.Auth.AuthenticateWatcher(ctx, t) if err != nil { - return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err) + return false, err } - log.Infof("You can successfully interact with Local API (LAPI)") - return nil + return true, nil } -func runLapiRegister(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() +func (cli *cliLapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { + cfg := cli.cfg() - apiURL, err := flags.GetString("url") - if err != nil { - return err - } + cred := cfg.API.Client.Credentials - outputFile, err := flags.GetString("file") - if err != nil { - return err - } + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Client.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) - lapiUser, err := flags.GetString("machine") + _, err := queryLAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) if err != nil { - return err + return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err) } + fmt.Fprintf(out, "You can successfully interact with Local API (LAPI)\n") + + return nil +} + +func (cli *cliLapi) register(ctx context.Context, apiURL string, outputFile string, machine string, token string) error { + var err error + + lapiUser := machine + cfg := cli.cfg() + if lapiUser == "" { - lapiUser, err = generateID("") + lapiUser, err = idgen.GenerateMachineID("") if err != nil { return fmt.Errorf("unable to generate machine id: %w", err) } } - password := strfmt.Password(generatePassword(passwordLength)) - if apiURL == "" { - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil || csConfig.API.Client.Credentials.URL == "" { - return fmt.Errorf("no Local API URL. Please provide it in your configuration or with the -u parameter") - } - apiURL = csConfig.API.Client.Credentials.URL - } - /*URL needs to end with /, but user doesn't care*/ - if !strings.HasSuffix(apiURL, "/") { - apiURL += "/" - } - /*URL needs to start with http://, but user doesn't care*/ - if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") { - apiURL = "http://" + apiURL - } - apiurl, err := url.Parse(apiURL) + + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) + + apiurl, err := prepareAPIURL(cfg.API.Client, apiURL) if err != nil { return fmt.Errorf("parsing api url: %w", err) } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: lapiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: LAPIURLPrefix, - }, nil) + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ + MachineID: lapiUser, + Password: password, + RegistrationToken: token, + URL: apiurl, + VersionPrefix: LAPIURLPrefix, + }, nil) if err != nil { return fmt.Errorf("api client register: %w", err) } @@ -130,138 +128,170 @@ func runLapiRegister(cmd *cobra.Command, args []string) error { log.Printf("Successfully registered to Local API (LAPI)") var dumpFile string + if outputFile != "" { dumpFile = outputFile - } else if csConfig.API.Client.CredentialsFilePath != "" { - dumpFile = csConfig.API.Client.CredentialsFilePath + } else if cfg.API.Client.CredentialsFilePath != "" { + dumpFile = cfg.API.Client.CredentialsFilePath } else { dumpFile = "" } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: lapiUser, - Password: password.String(), - URL: apiURL, + + apiCfg := cfg.API.Client.Credentials + apiCfg.Login = lapiUser + apiCfg.Password = password.String() + + if apiURL != "" { + apiCfg.URL = apiURL } + apiConfigDump, err := yaml.Marshal(apiCfg) if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) + return fmt.Errorf("unable to serialize api credentials: %w", err) } + if dumpFile != "" { err = os.WriteFile(dumpFile, apiConfigDump, 0o600) if err != nil { return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err) } + log.Printf("Local API credentials written to '%s'", dumpFile) } else { fmt.Printf("%s\n", string(apiConfigDump)) } - log.Warning(ReloadMessage()) + + log.Warning(reload.Message) return nil } -func NewLapiStatusCmd() *cobra.Command { +// prepareAPIURL checks/fixes a LAPI connection url (http, https or socket) and returns an URL struct +func prepareAPIURL(clientCfg *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) { + if apiURL == "" { + if clientCfg == nil || clientCfg.Credentials == nil || clientCfg.Credentials.URL == "" { + return nil, errors.New("no Local API URL. Please provide it in your configuration or with the -u parameter") + } + + apiURL = clientCfg.Credentials.URL + } + + // URL needs to end with /, but user doesn't care + if !strings.HasSuffix(apiURL, "/") { + apiURL += "/" + } + + // URL needs to start with http://, but user doesn't care + if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") { + apiURL = "http://" + apiURL + } + + return url.Parse(apiURL) +} + +func (cli *cliLapi) newStatusCmd() *cobra.Command { cmdLapiStatus := &cobra.Command{ Use: "status", Short: "Check authentication to Local API (LAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: runLapiStatus, + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) + }, } return cmdLapiStatus } -func NewLapiRegisterCmd() *cobra.Command { - cmdLapiRegister := &cobra.Command{ +func (cli *cliLapi) newRegisterCmd() *cobra.Command { + var ( + apiURL string + outputFile string + machine string + token string + ) + + cmd := &cobra.Command{ Use: "register", Short: "Register a machine to Local API (LAPI)", Long: `Register your machine to the Local API (LAPI). Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: runLapiRegister, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), apiURL, outputFile, machine, token) + }, } - flags := cmdLapiRegister.Flags() - flags.StringP("url", "u", "", "URL of the API (ie. http://127.0.0.1)") - flags.StringP("file", "f", "", "output file destination") - flags.String("machine", "", "Name of the machine to register with") + flags := cmd.Flags() + flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. http://127.0.0.1)") + flags.StringVarP(&outputFile, "file", "f", "", "output file destination") + flags.StringVar(&machine, "machine", "", "Name of the machine to register with") + flags.StringVar(&token, "token", "", "Auto registration token to use") - return cmdLapiRegister + return cmd } -func NewLapiCmd() *cobra.Command { - cmdLapi := &cobra.Command{ +func (cli *cliLapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "lapi [action]", Short: "Manage interaction with Local API (LAPI)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadAPIClient(); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { return fmt.Errorf("loading api client: %w", err) } return nil }, } - cmdLapi.AddCommand(NewLapiRegisterCmd()) - cmdLapi.AddCommand(NewLapiStatusCmd()) - cmdLapi.AddCommand(NewLapiContextCmd()) + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newContextCmd()) - return cmdLapi + return cmd } -func AddContext(key string, values []string) error { +func (cli *cliLapi) addContext(key string, values []string) error { + cfg := cli.cfg() + if err := alertcontext.ValidateContextExpr(key, values); err != nil { - return fmt.Errorf("invalid context configuration :%s", err) + return fmt.Errorf("invalid context configuration: %w", err) } - if _, ok := csConfig.Crowdsec.ContextToSend[key]; !ok { - csConfig.Crowdsec.ContextToSend[key] = make([]string, 0) + + if _, ok := cfg.Crowdsec.ContextToSend[key]; !ok { + cfg.Crowdsec.ContextToSend[key] = make([]string, 0) log.Infof("key '%s' added", key) } - data := csConfig.Crowdsec.ContextToSend[key] + + data := cfg.Crowdsec.ContextToSend[key] + for _, val := range values { if !slices.Contains(data, val) { log.Infof("value '%s' added to key '%s'", val, key) data = append(data, val) } - csConfig.Crowdsec.ContextToSend[key] = data - } - if err := csConfig.Crowdsec.DumpContextConfigFile(); err != nil { - return err + + cfg.Crowdsec.ContextToSend[key] = data } - return nil + return cfg.Crowdsec.DumpContextConfigFile() } -func NewLapiContextCmd() *cobra.Command { - cmdContext := &cobra.Command{ - Use: "context [command]", - Short: "Manage context to send with alerts", - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadCrowdsec(); err != nil { - fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", csConfig.Crowdsec.ConsoleContextPath) - if err.Error() != fileNotFoundMessage { - return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err) - } - } - if csConfig.DisableAgent { - return errors.New("agent is disabled and lapi context can only be used on the agent") - } - - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - printHelp(cmd) - }, - } +func (cli *cliLapi) newContextAddCmd() *cobra.Command { + var ( + keyToAdd string + valuesToAdd []string + ) - var keyToAdd string - var valuesToAdd []string - cmdContextAdd := &cobra.Command{ + cmd := &cobra.Command{ Use: "add", Short: "Add context to send with alerts. You must specify the output key with the expr value you want", Example: `cscli lapi context add --key source_ip --value evt.Meta.source_ip @@ -269,28 +299,25 @@ cscli lapi context add --key file_source --value evt.Line.Src cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user `, DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - hub, err := require.Hub(csConfig, nil, nil) + RunE: func(_ *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) if err != nil { return err } - if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil { + if err = alertcontext.LoadConsoleContext(cli.cfg(), hub); err != nil { return fmt.Errorf("while loading context: %w", err) } if keyToAdd != "" { - if err := AddContext(keyToAdd, valuesToAdd); err != nil { - return err - } - return nil + return cli.addContext(keyToAdd, valuesToAdd) } for _, v := range valuesToAdd { keySlice := strings.Split(v, ".") key := keySlice[len(keySlice)-1] value := []string{v} - if err := AddContext(key, value); err != nil { + if err := cli.addContext(key, value); err != nil { return err } } @@ -298,31 +325,38 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user return nil }, } - cmdContextAdd.Flags().StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") - cmdContextAdd.Flags().StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") - cmdContextAdd.MarkFlagRequired("value") - cmdContext.AddCommand(cmdContextAdd) - cmdContextStatus := &cobra.Command{ + flags := cmd.Flags() + flags.StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send") + flags.StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key") + + _ = cmd.MarkFlagRequired("value") + + return cmd +} + +func (cli *cliLapi) newContextStatusCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "status", Short: "List context to send with alerts", DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - hub, err := require.Hub(csConfig, nil, nil) + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } - if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil { + if err = alertcontext.LoadConsoleContext(cfg, hub); err != nil { return fmt.Errorf("while loading context: %w", err) } - if len(csConfig.Crowdsec.ContextToSend) == 0 { + if len(cfg.Crowdsec.ContextToSend) == 0 { fmt.Println("No context found on this agent. You can use 'cscli lapi context add' to add context to your alerts.") return nil } - dump, err := yaml.Marshal(csConfig.Crowdsec.ContextToSend) + dump, err := yaml.Marshal(cfg.Crowdsec.ContextToSend) if err != nil { return fmt.Errorf("unable to show context status: %w", err) } @@ -332,10 +366,14 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user return nil }, } - cmdContext.AddCommand(cmdContextStatus) + return cmd +} + +func (cli *cliLapi) newContextDetectCmd() *cobra.Command { var detectAll bool - cmdContextDetect := &cobra.Command{ + + cmd := &cobra.Command{ Use: "detect", Short: "Detect available fields from the installed parsers", Example: `cscli lapi context detect --all @@ -343,9 +381,10 @@ cscli lapi context detect crowdsecurity/sshd-logs `, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() if !detectAll && len(args) == 0 { - log.Infof("Please provide parsers to detect or --all flag.") - printHelp(cmd) + _ = cmd.Help() + return errors.New("please provide parsers to detect or --all flag") } // to avoid all the log.Info from the loaders functions @@ -355,13 +394,13 @@ cscli lapi context detect crowdsecurity/sshd-logs return fmt.Errorf("failed to init expr helpers: %w", err) } - hub, err := require.Hub(csConfig, nil, nil) + hub, err := require.Hub(cfg, nil, nil) if err != nil { return err } csParsers := parser.NewParsers(hub) - if csParsers, err = parser.LoadParsers(csConfig, csParsers); err != nil { + if csParsers, err = parser.LoadParsers(cfg, csParsers); err != nil { return fmt.Errorf("unable to load parsers: %w", err) } @@ -418,47 +457,81 @@ cscli lapi context detect crowdsecurity/sshd-logs return nil }, } - cmdContextDetect.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") - cmdContext.AddCommand(cmdContextDetect) + cmd.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser") - cmdContextDelete := &cobra.Command{ + return cmd +} + +func (cli *cliLapi) newContextDeleteCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "delete", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - filePath := csConfig.Crowdsec.ConsoleContextPath + filePath := cli.cfg().Crowdsec.ConsoleContextPath if filePath == "" { filePath = "the context file" } - fmt.Printf("Command \"delete\" is deprecated, please manually edit %s.", filePath) + + return fmt.Errorf("command 'delete' has been removed, please manually edit %s", filePath) + }, + } + + return cmd +} + +func (cli *cliLapi) newContextCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "context [command]", + Short: "Manage context to send with alerts", + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadCrowdsec(); err != nil { + fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", cfg.Crowdsec.ConsoleContextPath) + if err.Error() != fileNotFoundMessage { + return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err) + } + } + if cfg.DisableAgent { + return errors.New("agent is disabled and lapi context can only be used on the agent") + } + return nil }, } - cmdContext.AddCommand(cmdContextDelete) - return cmdContext + cmd.AddCommand(cli.newContextAddCmd()) + cmd.AddCommand(cli.newContextStatusCmd()) + cmd.AddCommand(cli.newContextDetectCmd()) + cmd.AddCommand(cli.newContextDeleteCmd()) + + return cmd } -func detectStaticField(GrokStatics []parser.ExtraField) []string { +func detectStaticField(grokStatics []parser.ExtraField) []string { ret := make([]string, 0) - for _, static := range GrokStatics { + for _, static := range grokStatics { if static.Parsed != "" { - fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed) + fieldName := "evt.Parsed." + static.Parsed if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } + if static.Meta != "" { - fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta) + fieldName := "evt.Meta." + static.Meta if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } + if static.TargetByName != "" { fieldName := static.TargetByName if !strings.HasPrefix(fieldName, "evt.") { fieldName = "evt." + fieldName } + if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -473,7 +546,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { if node.Grok.RunTimeRegexp != nil { for _, capturedField := range node.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -485,7 +558,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { // ignore error (parser does not exist?) if err == nil { for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -515,23 +588,24 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { } func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { - var ret = make([]string, 0) + ret := make([]string, 0) for _, subnode := range node.LeavesNodes { if subnode.Grok.RunTimeRegexp != nil { for _, capturedField := range subnode.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } } + if subnode.Grok.RegexpName != "" { grokCompiled, err := parserCTX.Grok.Get(subnode.Grok.RegexpName) if err == nil { // ignore error (parser does not exist?) for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } diff --git a/cmd/crowdsec-cli/clilapi/lapi_test.go b/cmd/crowdsec-cli/clilapi/lapi_test.go new file mode 100644 index 00000000000..caf986d847a --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/lapi_test.go @@ -0,0 +1,49 @@ +package clilapi + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +func TestPrepareAPIURL_NoProtocol(t *testing.T) { + url, err := prepareAPIURL(nil, "localhost:81") + require.NoError(t, err) + assert.Equal(t, "http://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_Http(t *testing.T) { + url, err := prepareAPIURL(nil, "http://localhost:81") + require.NoError(t, err) + assert.Equal(t, "http://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_Https(t *testing.T) { + url, err := prepareAPIURL(nil, "https://localhost:81") + require.NoError(t, err) + assert.Equal(t, "https://localhost:81/", url.String()) +} + +func TestPrepareAPIURL_UnixSocket(t *testing.T) { + url, err := prepareAPIURL(nil, "/path/socket") + require.NoError(t, err) + assert.Equal(t, "/path/socket/", url.String()) +} + +func TestPrepareAPIURL_Empty(t *testing.T) { + _, err := prepareAPIURL(nil, "") + require.Error(t, err) +} + +func TestPrepareAPIURL_Empty_ConfigOverride(t *testing.T) { + url, err := prepareAPIURL(&csconfig.LocalApiClientCfg{ + Credentials: &csconfig.ApiCredentialsCfg{ + URL: "localhost:80", + }, + }, "") + require.NoError(t, err) + assert.Equal(t, "http://localhost:80/", url.String()) +} diff --git a/cmd/crowdsec-cli/clilapi/utils.go b/cmd/crowdsec-cli/clilapi/utils.go new file mode 100644 index 00000000000..e3ec65f2145 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/utils.go @@ -0,0 +1,24 @@ +package clilapi + +func removeFromSlice(val string, slice []string) []string { + var i int + var value string + + valueFound := false + + // get the index + for i, value = range slice { + if value == val { + valueFound = true + break + } + } + + if valueFound { + slice[i] = slice[len(slice)-1] + slice[len(slice)-1] = "" + slice = slice[:len(slice)-1] + } + + return slice +} diff --git a/cmd/crowdsec-cli/climachine/flag.go b/cmd/crowdsec-cli/climachine/flag.go new file mode 100644 index 00000000000..c3fefd896e1 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/flag.go @@ -0,0 +1,29 @@ +package climachine + +// Custom types for flag validation and conversion. + +import ( + "errors" +) + +type MachinePassword string + +func (p *MachinePassword) String() string { + return string(*p) +} + +func (p *MachinePassword) Set(v string) error { + // a password can't be more than 72 characters + // due to bcrypt limitations + if len(v) > 72 { + return errors.New("password too long (max 72 characters)") + } + + *p = MachinePassword(v) + + return nil +} + +func (p *MachinePassword) Type() string { + return "string" +} diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go new file mode 100644 index 00000000000..1fbedcf57fd --- /dev/null +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -0,0 +1,717 @@ +package climachine + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "slices" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// getLastHeartbeat returns the last heartbeat timestamp of a machine +// and a boolean indicating if the machine is considered active or not. +func getLastHeartbeat(m *ent.Machine) (string, bool) { + if m.LastHeartbeat == nil { + return "-", false + } + + elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) + + hb := elapsed.Truncate(time.Second).String() + if elapsed > 2*time.Minute { + return hb, false + } + + return hb, true +} + +type configGetter = func() *csconfig.Config + +type cliMachines struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliMachines { + return &cliMachines{ + cfg: cfg, + } +} + +func (cli *cliMachines) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "machines [action]", + Short: "Manage local API machines [requires local API]", + Long: `To list/add/delete/validate/prune machines. +Note: This command requires database direct access, so is intended to be run on the local API machine. +`, + Example: `cscli machines [action]`, + DisableAutoGenTag: true, + Aliases: []string{"machine"}, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + if err = require.LAPI(cli.cfg()); err != nil { + return err + } + cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newValidateCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +func (cli *cliMachines) inspectHubHuman(out io.Writer, machine *ent.Machine) { + state := machine.Hubstate + + if len(state) == 0 { + fmt.Println("No hub items found for this machine") + return + } + + // group state rows by type for multiple tables + rowsByType := make(map[string][]table.Row) + + for itemType, items := range state { + for _, item := range items { + if _, ok := rowsByType[itemType]; !ok { + rowsByType[itemType] = make([]table.Row, 0) + } + + row := table.Row{item.Name, item.Status, item.Version} + rowsByType[itemType] = append(rowsByType[itemType], row) + } + } + + for itemType, rows := range rowsByType { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "Status", "Version"}) + t.SetTitle(itemType) + t.AppendRows(rows) + io.WriteString(out, t.Render()+"\n") + } +} + +func (cli *cliMachines) listHuman(out io.Writer, machines ent.Machines) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Last Update", "Status", "Version", "OS", "Auth Type", "Last Heartbeat"}) + + for _, m := range machines { + validated := emoji.Prohibited + if m.IsValidated { + validated = emoji.CheckMark + } + + hb, active := getLastHeartbeat(m) + if !active { + hb = emoji.Warning + " " + hb + } + + t.AppendRow(table.Row{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, clientinfo.GetOSNameAndVersion(m), m.AuthType, hb}) + } + + io.WriteString(out, t.Render()+"\n") +} + +// machineInfo contains only the data we want for inspect/list: no hub status, scenarios, edges, etc. +type machineInfo struct { + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + LastPush *time.Time `json:"last_push,omitempty"` + LastHeartbeat *time.Time `json:"last_heartbeat,omitempty"` + MachineId string `json:"machineId,omitempty"` + IpAddress string `json:"ipAddress,omitempty"` + Version string `json:"version,omitempty"` + IsValidated bool `json:"isValidated,omitempty"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` + Datasources map[string]int64 `json:"datasources,omitempty"` +} + +func newMachineInfo(m *ent.Machine) machineInfo { + return machineInfo{ + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + LastPush: m.LastPush, + LastHeartbeat: m.LastHeartbeat, + MachineId: m.MachineId, + IpAddress: m.IpAddress, + Version: m.Version, + IsValidated: m.IsValidated, + AuthType: m.AuthType, + OS: clientinfo.GetOSNameAndVersion(m), + Featureflags: clientinfo.GetFeatureFlagList(m), + Datasources: m.Datasources, + } +} + +func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat", "os"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + for _, m := range machines { + validated := "false" + if m.IsValidated { + validated = "true" + } + + hb := "-" + if m.LastHeartbeat != nil { + hb = m.LastHeartbeat.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb, fmt.Sprintf("%s/%s", m.Osname, m.Osversion)}); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + machines, err := db.ListMachines(ctx) + if err != nil { + return fmt.Errorf("unable to list machines: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, machines) + case "json": + info := make([]machineInfo, 0, len(machines)) + for _, m := range machines { + info = append(info, newMachineInfo(m)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, machines) + } + + return nil +} + +func (cli *cliMachines) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all machines in the database", + Long: `list all machines in the database with their status and last heartbeat`, + Example: `cscli machines list`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} + +func (cli *cliMachines) newAddCmd() *cobra.Command { + var ( + password MachinePassword + dumpFile string + apiURL string + interactive bool + autoAdd bool + force bool + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "add a single machine to the database", + DisableAutoGenTag: true, + Long: `Register a new machine in the database. cscli should be on the same machine as LAPI.`, + Example: `cscli machines add --auto +cscli machines add MyTestMachine --auto +cscli machines add MyTestMachine --password MyPassword +cscli machines add -f- --auto > /tmp/mycreds.yaml`, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + }, + } + + flags := cmd.Flags() + flags.VarP(&password, "password", "p", "machine password to login to the API") + flags.StringVarP(&dumpFile, "file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") + flags.StringVarP(&apiURL, "url", "u", "", "URL of the local API") + flags.BoolVarP(&interactive, "interactive", "i", false, "interfactive mode to enter the password") + flags.BoolVarP(&autoAdd, "auto", "a", false, "automatically generate password (and username if not provided)") + flags.BoolVar(&force, "force", false, "will force add the machine if it already exist") + + return cmd +} + +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { + var ( + err error + machineID string + ) + + // create machineID if not specified by user + if len(args) == 0 { + if !autoAdd { + return errors.New("please specify a machine name to add, or use --auto") + } + + machineID, err = idgen.GenerateMachineID("") + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + } else { + machineID = args[0] + } + + clientCfg := cli.cfg().API.Client + serverCfg := cli.cfg().API.Server + + /*check if file already exists*/ + if dumpFile == "" && clientCfg != nil && clientCfg.CredentialsFilePath != "" { + credFile := clientCfg.CredentialsFilePath + // use the default only if the file does not exist + _, err = os.Stat(credFile) + + switch { + case os.IsNotExist(err) || force: + dumpFile = credFile + case err != nil: + return fmt.Errorf("unable to stat '%s': %w", credFile, err) + default: + return fmt.Errorf(`credentials file '%s' already exists: please remove it, use "--force" or specify a different file with "-f" ("-f -" for standard output)`, credFile) + } + } + + if dumpFile == "" { + return errors.New(`please specify a file to dump credentials to, with -f ("-f -" for standard output)`) + } + + // create a password if it's not specified by user + if machinePassword == "" && !interactive { + if !autoAdd { + return errors.New("please specify a password with --password or use --auto") + } + + machinePassword = idgen.GeneratePassword(idgen.PasswordLength) + } else if machinePassword == "" && interactive { + qs := &survey.Password{ + Message: "Please provide a password for the machine:", + } + survey.AskOne(qs, &machinePassword) + } + + password := strfmt.Password(machinePassword) + + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) + if err != nil { + return fmt.Errorf("unable to create machine: %w", err) + } + + fmt.Fprintf(os.Stderr, "Machine '%s' successfully added to the local API.\n", machineID) + + if apiURL == "" { + if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" { + apiURL = clientCfg.Credentials.URL + } else if serverCfg.ClientURL() != "" { + apiURL = serverCfg.ClientURL() + } else { + return errors.New("unable to dump an api URL. Please provide it in your configuration or with the -u parameter") + } + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: machineID, + Password: password.String(), + URL: apiURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" && dumpFile != "-" { + if err = os.WriteFile(dumpFile, apiConfigDump, 0o600); err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + fmt.Fprintf(os.Stderr, "API credentials written to '%s'.\n", dumpFile) + } else { + fmt.Print(string(apiConfigDump)) + } + + return nil +} + +// validMachineID returns a list of machine IDs for command completion +func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + machines, err := cli.db.ListMachines(ctx) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, machine := range machines { + if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { + ret = append(ret, machine.MachineId) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error { + for _, machineID := range machines { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { + var notFoundErr *database.MachineNotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + return nil + } + + log.Errorf("unable to delete machine: %s", err) + + return nil + } + + log.Infof("machine '%s' deleted successfully", machineID) + } + + return nil +} + +func (cli *cliMachines) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete [machine_name]...", + Short: "delete machine(s) by name", + Example: `cscli machines delete "machine1" "machine2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more machines don't exist") + + return cmd +} + +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { + if duration < 2*time.Minute && !notValidOnly { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This can break installations if the machines are only temporarily disconnected. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + machines := []*ent.Machine{} + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { + machines = append(machines, pending...) + } + + if !notValidOnly { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { + machines = append(machines, pending...) + } + } + + if len(machines) == 0 { + fmt.Println("No machines to prune.") + return nil + } + + cli.listHuman(color.Output, machines) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above machines from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) + if err != nil { + return fmt.Errorf("unable to prune machines: %w", err) + } + + fmt.Fprintf(os.Stderr, "successfully deleted %d machines\n", deleted) + + return nil +} + +func (cli *cliMachines) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + notValidOnly bool + force bool + ) + + const defaultDuration = 10 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple machines from the database", + Long: `prune multiple machines that are not validated or have not connected to the local API in a given duration.`, + Example: `cscli machines prune +cscli machines prune --duration 1h +cscli machines prune --not-validated-only --force`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since validated machine last heartbeat") + flags.BoolVar(¬ValidOnly, "not-validated-only", false, "only prune machines that are not validated") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} + +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { + return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) + } + + log.Infof("machine '%s' validated successfully", machineID) + + return nil +} + +func (cli *cliMachines) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate", + Short: "validate a machine to access the local API", + Long: `validate a machine to access the local API.`, + Example: `cscli machines validate "machine_name"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) + }, + } + + return cmd +} + +func (cli *cliMachines) inspectHuman(out io.Writer, machine *ent.Machine) { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Machine: " + machine.MachineId) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + t.AppendRows([]table.Row{ + {"IP Address", machine.IpAddress}, + {"Created At", machine.CreatedAt}, + {"Last Update", machine.UpdatedAt}, + {"Last Heartbeat", machine.LastHeartbeat}, + {"Validated?", machine.IsValidated}, + {"CrowdSec version", machine.Version}, + {"OS", clientinfo.GetOSNameAndVersion(machine)}, + {"Auth type", machine.AuthType}, + }) + + for dsName, dsCount := range machine.Datasources { + t.AppendRow(table.Row{"Datasources", fmt.Sprintf("%s: %d", dsName, dsCount)}) + } + + for _, ff := range clientinfo.GetFeatureFlagList(machine) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + for _, coll := range machine.Hubstate[cwhub.COLLECTIONS] { + t.AppendRow(table.Row{"Collections", coll.Name}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliMachines) inspect(machine *ent.Machine) error { + out := color.Output + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newMachineInfo(machine)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliMachines) inspectHub(machine *ent.Machine) error { + out := color.Output + + switch cli.cfg().Cscli.Output { + case "human": + cli.inspectHubHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(machine.Hubstate); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"type", "name", "status", "version"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + rows := make([][]string, 0) + + for itemType, items := range machine.Hubstate { + for _, item := range items { + rows = append(rows, []string{itemType, item.Name, item.Status, item.Version}) + } + } + + for _, row := range rows { + if err := csvwriter.Write(row); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + } + + return nil +} + +func (cli *cliMachines) newInspectCmd() *cobra.Command { + var showHub bool + + cmd := &cobra.Command{ + Use: "inspect [machine_name]", + Short: "inspect a machine by name", + Example: `cscli machines inspect "machine1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + machineID := args[0] + + machine, err := cli.db.QueryMachineByID(ctx, machineID) + if err != nil { + return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) + } + + if showHub { + return cli.inspectHub(machine) + } + + return cli.inspect(machine) + }, + } + + flags := cmd.Flags() + + flags.BoolVarP(&showHub, "hub", "H", false, "show hub state") + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/list.go b/cmd/crowdsec-cli/climetrics/list.go new file mode 100644 index 00000000000..ddb2baac14d --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/list.go @@ -0,0 +1,95 @@ +package climetrics + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +func (cli *cliMetrics) list() error { + type metricType struct { + Type string `json:"type" yaml:"type"` + Title string `json:"title" yaml:"title"` + Description string `json:"description" yaml:"description"` + } + + var allMetrics []metricType + + ms := NewMetricStore() + for _, section := range maptools.SortedKeys(ms) { + title, description := ms[section].Description() + allMetrics = append(allMetrics, metricType{ + Type: section, + Title: title, + Description: description, + }) + } + + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + out := color.Output + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Type", "Title", "Description"}) + t.SetColumnConfigs([]table.ColumnConfig{ + { + Name: "Type", + AlignHeader: text.AlignCenter, + }, + { + Name: "Title", + AlignHeader: text.AlignCenter, + }, + { + Name: "Description", + AlignHeader: text.AlignCenter, + WidthMax: 60, + WidthMaxEnforcer: text.WrapSoft, + }, + }) + + t.Style().Options.SeparateRows = true + + for _, metric := range allMetrics { + t.AppendRow(table.Row{metric.Type, metric.Title, metric.Description}) + } + + io.WriteString(out, t.Render()+"\n") + case "json": + x, err := json.MarshalIndent(allMetrics, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize metric types: %w", err) + } + + fmt.Println(string(x)) + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliMetrics) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List available types of metrics.", + Long: `List available types of metrics.`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.list() + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/metrics.go b/cmd/crowdsec-cli/climetrics/metrics.go new file mode 100644 index 00000000000..f3bc4874460 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/metrics.go @@ -0,0 +1,54 @@ +package climetrics + +import ( + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +type configGetter func() *csconfig.Config + +type cliMetrics struct { + cfg configGetter +} + +func New(cfg configGetter) *cliMetrics { + return &cliMetrics{ + cfg: cfg, + } +} + +func (cli *cliMetrics) NewCommand() *cobra.Command { + var ( + url string + noUnit bool + ) + + cmd := &cobra.Command{ + Use: "metrics", + Short: "Display crowdsec prometheus metrics.", + Long: `Fetch metrics from a Local API server and display them`, + Example: `# Show all Metrics, skip empty tables (same as "cecli metrics show") +cscli metrics + +# Show only some metrics, connect to a different url +cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers + +# List available metric types +cscli metrics list`, + Args: cobra.ExactArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.show(cmd.Context(), nil, url, noUnit) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Prometheus url (http://:/metrics)") + flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newListCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/number.go b/cmd/crowdsec-cli/climetrics/number.go new file mode 100644 index 00000000000..709b7cf853a --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/number.go @@ -0,0 +1,45 @@ +package climetrics + +import ( + "fmt" + "math" + "strconv" +) + +type unit struct { + value int64 + symbol string +} + +var ranges = []unit{ + {value: 1e18, symbol: "E"}, + {value: 1e15, symbol: "P"}, + {value: 1e12, symbol: "T"}, + {value: 1e9, symbol: "G"}, + {value: 1e6, symbol: "M"}, + {value: 1e3, symbol: "k"}, + {value: 1, symbol: ""}, +} + +func formatNumber(num int64, withUnit bool) string { + if !withUnit { + return strconv.FormatInt(num, 10) + } + + goodUnit := ranges[len(ranges)-1] + + for _, u := range ranges { + if num >= u.value { + goodUnit = u + break + } + } + + if goodUnit.value == 1 { + return fmt.Sprintf("%d%s", num, goodUnit.symbol) + } + + res := math.Round(float64(num)/float64(goodUnit.value)*100) / 100 + + return fmt.Sprintf("%.2f%s", res, goodUnit.symbol) +} diff --git a/cmd/crowdsec-cli/climetrics/show.go b/cmd/crowdsec-cli/climetrics/show.go new file mode 100644 index 00000000000..045959048f6 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/show.go @@ -0,0 +1,113 @@ +package climetrics + +import ( + "context" + "errors" + "fmt" + + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" +) + +var ( + ErrMissingConfig = errors.New("prometheus section missing, can't show metrics") + ErrMetricsDisabled = errors.New("prometheus is not enabled, can't show metrics") +) + +func (cli *cliMetrics) show(ctx context.Context, sections []string, url string, noUnit bool) error { + cfg := cli.cfg() + + if url != "" { + cfg.Cscli.PrometheusUrl = url + } + + if cfg.Prometheus == nil { + return ErrMissingConfig + } + + if !cfg.Prometheus.Enabled { + return ErrMetricsDisabled + } + + ms := NewMetricStore() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + log.Warnf("unable to open database: %s", err) + } + + if err := ms.Fetch(ctx, cfg.Cscli.PrometheusUrl, db); err != nil { + log.Warn(err) + } + + // any section that we don't have in the store is an error + for _, section := range sections { + if _, ok := ms[section]; !ok { + return fmt.Errorf("unknown metrics type: %s", section) + } + } + + return ms.Format(color.Output, cfg.Cscli.Color, sections, cfg.Cscli.Output, noUnit) +} + +// expandAlias returns a list of sections. The input can be a list of sections or alias. +func expandAlias(args []string) []string { + ret := []string{} + + for _, section := range args { + switch section { + case "engine": + ret = append(ret, "acquisition", "parsers", "scenarios", "stash", "whitelists") + case "lapi": + ret = append(ret, "alerts", "decisions", "lapi", "lapi-bouncer", "lapi-decisions", "lapi-machine") + case "appsec": + ret = append(ret, "appsec-engine", "appsec-rule") + default: + ret = append(ret, section) + } + } + + return ret +} + +func (cli *cliMetrics) newShowCmd() *cobra.Command { + var ( + url string + noUnit bool + ) + + cmd := &cobra.Command{ + Use: "show [type]...", + Short: "Display all or part of the available metrics.", + Long: `Fetch metrics from a Local API server and display them, optionally filtering on specific types.`, + Example: `# Show all Metrics, skip empty tables +cscli metrics show + +# Use an alias: "engine", "lapi" or "appsec" to show a group of metrics +cscli metrics show engine + +# Show some specific metrics, show empty tables, connect to a different url +cscli metrics show acquisition parsers scenarios stash --url http://lapi.local:6060/metrics + +# To list available metric types, use "cscli metrics list" +cscli metrics list; cscli metrics list -o json + +# Show metrics in json format +cscli metrics show acquisition parsers scenarios stash -o json`, + // Positional args are optional + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + args = expandAlias(args) + return cli.show(cmd.Context(), args, url, noUnit) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&url, "url", "u", "", "Metrics url (http://:/metrics)") + flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units") + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/statacquis.go b/cmd/crowdsec-cli/climetrics/statacquis.go new file mode 100644 index 00000000000..0af2e796f40 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statacquis.go @@ -0,0 +1,44 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAcquis map[string]map[string]int + +func (s statAcquis) Description() (string, string) { + return "Acquisition Metrics", + `Measures the lines read, parsed, and unparsed per datasource. ` + + `Zero read lines indicate a misconfigured or inactive datasource. ` + + `Zero parsed lines means the parser(s) failed. ` + + `Non-zero parsed lines are fine as crowdsec selects relevant lines.` +} + +func (s statAcquis) Process(source, metric string, val int) { + if _, ok := s[source]; !ok { + s[source] = make(map[string]int) + } + + s[source][metric] += val +} + +func (s statAcquis) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket", "Lines whitelisted"}) + + keys := []string{"reads", "parsed", "unparsed", "pour", "whitelisted"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting acquis stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statalert.go b/cmd/crowdsec-cli/climetrics/statalert.go new file mode 100644 index 00000000000..942eceaa75c --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statalert.go @@ -0,0 +1,45 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAlert map[string]int + +func (s statAlert) Description() (string, string) { + return "Local API Alerts", + `Tracks the total number of past and present alerts for the installed scenarios.` +} + +func (s statAlert) Process(reason string, val int) { + s[reason] += val +} + +func (s statAlert) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Reason", "Count"}) + + numRows := 0 + + // TODO: sort keys + for scenario, hits := range s { + t.AppendRow(table.Row{ + scenario, + strconv.Itoa(hits), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statappsecengine.go b/cmd/crowdsec-cli/climetrics/statappsecengine.go new file mode 100644 index 00000000000..d924375247f --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statappsecengine.go @@ -0,0 +1,41 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAppsecEngine map[string]map[string]int + +func (s statAppsecEngine) Description() (string, string) { + return "Appsec Metrics", + `Measures the number of parsed and blocked requests by the AppSec Component.` +} + +func (s statAppsecEngine) Process(appsecEngine, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = make(map[string]int) + } + + s[appsecEngine][metric] += val +} + +func (s statAppsecEngine) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Appsec Engine", "Processed", "Blocked"}) + + keys := []string{"processed", "blocked"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting appsec stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statappsecrule.go b/cmd/crowdsec-cli/climetrics/statappsecrule.go new file mode 100644 index 00000000000..e06a7c2e2b3 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statappsecrule.go @@ -0,0 +1,48 @@ +package climetrics + +import ( + "fmt" + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statAppsecRule map[string]map[string]map[string]int + +func (s statAppsecRule) Description() (string, string) { + return "Appsec Rule Metrics", + `Provides “per AppSec Component” information about the number of matches for loaded AppSec Rules.` +} + +func (s statAppsecRule) Process(appsecEngine, appsecRule string, metric string, val int) { + if _, ok := s[appsecEngine]; !ok { + s[appsecEngine] = make(map[string]map[string]int) + } + + if _, ok := s[appsecEngine][appsecRule]; !ok { + s[appsecEngine][appsecRule] = make(map[string]int) + } + + s[appsecEngine][appsecRule][metric] += val +} + +func (s statAppsecRule) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + // TODO: sort keys + for appsecEngine, appsecEngineRulesStats := range s { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Rule ID", "Triggered"}) + + keys := []string{"triggered"} + + if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys, noUnit); err != nil { + log.Warningf("while collecting appsec rules stats: %s", err) + } else if numRows > 0 || showEmpty { + io.WriteString(out, fmt.Sprintf("Appsec '%s' Rules Metrics:\n", appsecEngine)) + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } + } +} diff --git a/cmd/crowdsec-cli/climetrics/statbouncer.go b/cmd/crowdsec-cli/climetrics/statbouncer.go new file mode 100644 index 00000000000..bc0da152d6d --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statbouncer.go @@ -0,0 +1,461 @@ +package climetrics + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sort" + "strings" + "time" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +// bouncerMetricItem represents unaggregated, denormalized metric data. +// Possibly not unique if a bouncer sent the same data multiple times. +type bouncerMetricItem struct { + collectedAt time.Time + bouncerName string + ipType string + origin string + name string + unit string + value float64 +} + +// aggregationOverTime is the first level of aggregation: we aggregate +// over time, then over ip type, then over origin. we only sum values +// for non-gauge metrics, and take the last value for gauge metrics. +type aggregationOverTime map[string]map[string]map[string]map[string]map[string]int64 + +func (a aggregationOverTime) add(bouncerName, origin, name, unit, ipType string, value float64, isGauge bool) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin]; !ok { + a[bouncerName][origin] = make(map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name]; !ok { + a[bouncerName][origin][name] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name][unit]; !ok { + a[bouncerName][origin][name][unit] = make(map[string]int64) + } + + if isGauge { + a[bouncerName][origin][name][unit][ipType] = int64(value) + } else { + a[bouncerName][origin][name][unit][ipType] += int64(value) + } +} + +// aggregationOverIPType is the second level of aggregation: data is summed +// regardless of the metrics type (gauge or not). This is used to display +// table rows, they won't differentiate ipv4 and ipv6 +type aggregationOverIPType map[string]map[string]map[string]map[string]int64 + +func (a aggregationOverIPType) add(bouncerName, origin, name, unit string, value int64) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin]; !ok { + a[bouncerName][origin] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][origin][name]; !ok { + a[bouncerName][origin][name] = make(map[string]int64) + } + + a[bouncerName][origin][name][unit] += value +} + +// aggregationOverOrigin is the third level of aggregation: these are +// the totals at the end of the table. Metrics without an origin will +// be added to the totals but not displayed in the rows, only in the footer. +type aggregationOverOrigin map[string]map[string]map[string]int64 + +func (a aggregationOverOrigin) add(bouncerName, name, unit string, value int64) { + if _, ok := a[bouncerName]; !ok { + a[bouncerName] = make(map[string]map[string]int64) + } + + if _, ok := a[bouncerName][name]; !ok { + a[bouncerName][name] = make(map[string]int64) + } + + a[bouncerName][name][unit] += value +} + +type statBouncer struct { + // oldest collection timestamp for each bouncer + oldestTS map[string]time.Time + // aggregate over ip type: always sum + // [bouncer][origin][name][unit]value + aggOverIPType aggregationOverIPType + // aggregate over origin: always sum + // [bouncer][name][unit]value + aggOverOrigin aggregationOverOrigin +} + +var knownPlurals = map[string]string{ + "byte": "bytes", + "packet": "packets", + "ip": "IPs", +} + +func (s *statBouncer) MarshalJSON() ([]byte, error) { + return json.Marshal(s.aggOverIPType) +} + +func (*statBouncer) Description() (string, string) { + return "Bouncer Metrics", + `Network traffic blocked by bouncers.` +} + +func logWarningOnce(warningsLogged map[string]bool, msg string) { + if _, ok := warningsLogged[msg]; !ok { + log.Warning(msg) + + warningsLogged[msg] = true + } +} + +// extractRawMetrics converts metrics from the database to a de-normalized, de-duplicated slice +// it returns the slice and the oldest timestamp for each bouncer +func (*statBouncer) extractRawMetrics(metrics []*ent.Metric) ([]bouncerMetricItem, map[string]time.Time) { + oldestTS := make(map[string]time.Time) + + // don't spam the user with the same warnings + warningsLogged := make(map[string]bool) + + // store raw metrics, de-duplicated in case some were sent multiple times + uniqueRaw := make(map[bouncerMetricItem]struct{}) + + for _, met := range metrics { + bouncerName := met.GeneratedBy + + var payload struct { + Metrics []models.DetailedMetrics `json:"metrics"` + } + + if err := json.Unmarshal([]byte(met.Payload), &payload); err != nil { + log.Warningf("while parsing metrics for %s: %s", bouncerName, err) + continue + } + + for _, m := range payload.Metrics { + // fields like timestamp, name, etc. are mandatory but we got pointers, so we check anyway + if m.Meta.UtcNowTimestamp == nil { + logWarningOnce(warningsLogged, "missing 'utc_now_timestamp' field in metrics reported by "+bouncerName) + continue + } + + collectedAt := time.Unix(*m.Meta.UtcNowTimestamp, 0).UTC() + + if oldestTS[bouncerName].IsZero() || collectedAt.Before(oldestTS[bouncerName]) { + oldestTS[bouncerName] = collectedAt + } + + for _, item := range m.Items { + valid := true + + if item.Name == nil { + logWarningOnce(warningsLogged, "missing 'name' field in metrics reported by "+bouncerName) + // no continue - keep checking the rest + valid = false + } + + if item.Unit == nil { + logWarningOnce(warningsLogged, "missing 'unit' field in metrics reported by "+bouncerName) + valid = false + } + + if item.Value == nil { + logWarningOnce(warningsLogged, "missing 'value' field in metrics reported by "+bouncerName) + valid = false + } + + if !valid { + continue + } + + rawMetric := bouncerMetricItem{ + collectedAt: collectedAt, + bouncerName: bouncerName, + ipType: item.Labels["ip_type"], + origin: item.Labels["origin"], + name: *item.Name, + unit: *item.Unit, + value: *item.Value, + } + + uniqueRaw[rawMetric] = struct{}{} + } + } + } + + // extract raw metric structs + keys := make([]bouncerMetricItem, 0, len(uniqueRaw)) + for key := range uniqueRaw { + keys = append(keys, key) + } + + // order them by timestamp + sort.Slice(keys, func(i, j int) bool { + return keys[i].collectedAt.Before(keys[j].collectedAt) + }) + + return keys, oldestTS +} + +func (s *statBouncer) Fetch(ctx context.Context, db *database.Client) error { + if db == nil { + return nil + } + + // query all bouncer metrics that have not been flushed + + metrics, err := db.Ent.Metric.Query(). + Where(metric.GeneratedTypeEQ(metric.GeneratedTypeRC)). + All(ctx) + if err != nil { + return fmt.Errorf("unable to fetch metrics: %w", err) + } + + // de-normalize, de-duplicate metrics and keep the oldest timestamp for each bouncer + + rawMetrics, oldestTS := s.extractRawMetrics(metrics) + + s.oldestTS = oldestTS + aggOverTime := s.newAggregationOverTime(rawMetrics) + s.aggOverIPType = s.newAggregationOverIPType(aggOverTime) + s.aggOverOrigin = s.newAggregationOverOrigin(s.aggOverIPType) + + return nil +} + +// return true if the metric is a gauge and should not be aggregated +func (*statBouncer) isGauge(name string) bool { + return name == "active_decisions" || strings.HasSuffix(name, "_gauge") +} + +// formatMetricName returns the metric name to display in the table header +func (*statBouncer) formatMetricName(name string) string { + return strings.TrimSuffix(name, "_gauge") +} + +// formatMetricOrigin returns the origin to display in the table rows +// (for example, some users don't know what capi is) +func (*statBouncer) formatMetricOrigin(origin string) string { + switch origin { + case "CAPI": + return origin + " (community blocklist)" + case "cscli": + return origin + " (manual decisions)" + case "crowdsec": + return origin + " (security engine)" + default: + return origin + } +} + +func (s *statBouncer) newAggregationOverTime(rawMetrics []bouncerMetricItem) aggregationOverTime { + ret := aggregationOverTime{} + + for _, raw := range rawMetrics { + ret.add(raw.bouncerName, raw.origin, raw.name, raw.unit, raw.ipType, raw.value, s.isGauge(raw.name)) + } + + return ret +} + +func (*statBouncer) newAggregationOverIPType(aggMetrics aggregationOverTime) aggregationOverIPType { + ret := aggregationOverIPType{} + + for bouncerName := range aggMetrics { + for origin := range aggMetrics[bouncerName] { + for name := range aggMetrics[bouncerName][origin] { + for unit := range aggMetrics[bouncerName][origin][name] { + for ipType := range aggMetrics[bouncerName][origin][name][unit] { + value := aggMetrics[bouncerName][origin][name][unit][ipType] + ret.add(bouncerName, origin, name, unit, value) + } + } + } + } + } + + return ret +} + +func (*statBouncer) newAggregationOverOrigin(aggMetrics aggregationOverIPType) aggregationOverOrigin { + ret := aggregationOverOrigin{} + + for bouncerName := range aggMetrics { + for origin := range aggMetrics[bouncerName] { + for name := range aggMetrics[bouncerName][origin] { + for unit := range aggMetrics[bouncerName][origin][name] { + val := aggMetrics[bouncerName][origin][name][unit] + ret.add(bouncerName, name, unit, val) + } + } + } + } + + return ret +} + +// bouncerTable displays a table of metrics for a single bouncer +func (s *statBouncer) bouncerTable(out io.Writer, bouncerName string, wantColor string, noUnit bool) { + columns := make(map[string]map[string]struct{}) + + bouncerData, ok := s.aggOverOrigin[bouncerName] + if !ok { + // no metrics for this bouncer, skip. how did we get here ? + // anyway we can't honor the "showEmpty" flag in this case, + // we don't even have the table headers + return + } + + for metricName, units := range bouncerData { + // build a map of the metric names and units, to display dynamic columns + columns[metricName] = make(map[string]struct{}) + for unit := range units { + columns[metricName][unit] = struct{}{} + } + } + + if len(columns) == 0 { + return + } + + t := cstable.New(out, wantColor).Writer + header1 := table.Row{"Origin"} + header2 := table.Row{""} + colNum := 1 + + colCfg := []table.ColumnConfig{{ + Number: colNum, + AlignHeader: text.AlignLeft, + Align: text.AlignLeft, + AlignFooter: text.AlignRight, + }} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + colNum += 1 + + header1 = append(header1, s.formatMetricName(name)) + + // we don't add "s" to random words + if plural, ok := knownPlurals[unit]; ok { + unit = plural + } + + header2 = append(header2, unit) + colCfg = append(colCfg, table.ColumnConfig{ + Number: colNum, + AlignHeader: text.AlignCenter, + Align: text.AlignRight, + AlignFooter: text.AlignRight, + }) + } + } + + t.AppendHeader(header1, table.RowConfig{AutoMerge: true}) + t.AppendHeader(header2) + + t.SetColumnConfigs(colCfg) + + numRows := 0 + + // sort all the ranges for stable output + + for _, origin := range maptools.SortedKeys(s.aggOverIPType[bouncerName]) { + if origin == "" { + // if the metric has no origin (i.e. processed bytes/packets) + // we don't display it in the table body but it still gets aggreagted + // in the footer's totals + continue + } + + metrics := s.aggOverIPType[bouncerName][origin] + + row := table.Row{s.formatMetricOrigin(origin)} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + valStr := "-" + + if val, ok := metrics[name][unit]; ok { + valStr = formatNumber(val, !noUnit) + } + + row = append(row, valStr) + } + } + + t.AppendRow(row) + + numRows += 1 + } + + totals := s.aggOverOrigin[bouncerName] + + if numRows == 0 { + t.Style().Options.SeparateFooter = false + } + + footer := table.Row{"Total"} + + for _, name := range maptools.SortedKeys(columns) { + for _, unit := range maptools.SortedKeys(columns[name]) { + footer = append(footer, formatNumber(totals[name][unit], !noUnit)) + } + } + + t.AppendFooter(footer) + + title, _ := s.Description() + title = fmt.Sprintf("%s (%s)", title, bouncerName) + + if s.oldestTS != nil { + // if you change this to .Local() beware of tests + title = fmt.Sprintf("%s since %s", title, s.oldestTS[bouncerName].String()) + } + + // don't use SetTitle() because it draws the title inside table box + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + // empty line between tables + io.WriteString(out, "\n") +} + +// Table displays a table of metrics for each bouncer +func (s *statBouncer) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + found := false + + for _, bouncerName := range maptools.SortedKeys(s.aggOverOrigin) { + s.bouncerTable(out, bouncerName, wantColor, noUnit) + found = true + } + + if !found && showEmpty { + io.WriteString(out, "No bouncer metrics found.\n\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statbucket.go b/cmd/crowdsec-cli/climetrics/statbucket.go new file mode 100644 index 00000000000..1882fe21df1 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statbucket.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statBucket map[string]map[string]int + +func (s statBucket) Description() (string, string) { + return "Scenario Metrics", + `Measure events in different scenarios. Current count is the number of buckets during metrics collection. ` + + `Overflows are past event-producing buckets, while Expired are the ones that didn’t receive enough events to Overflow.` +} + +func (s statBucket) Process(bucket, metric string, val int) { + if _, ok := s[bucket]; !ok { + s[bucket] = make(map[string]int) + } + + s[bucket][metric] += val +} + +func (s statBucket) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Scenario", "Current Count", "Overflows", "Instantiated", "Poured", "Expired"}) + + keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting scenario stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statdecision.go b/cmd/crowdsec-cli/climetrics/statdecision.go new file mode 100644 index 00000000000..b862f49ff12 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statdecision.go @@ -0,0 +1,60 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statDecision map[string]map[string]map[string]int + +func (s statDecision) Description() (string, string) { + return "Local API Decisions", + `Provides information about all currently active decisions. ` + + `Includes both local (crowdsec) and global decisions (CAPI), and lists subscriptions (lists).` +} + +func (s statDecision) Process(reason, origin, action string, val int) { + if _, ok := s[reason]; !ok { + s[reason] = make(map[string]map[string]int) + } + + if _, ok := s[reason][origin]; !ok { + s[reason][origin] = make(map[string]int) + } + + s[reason][origin][action] += val +} + +func (s statDecision) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Reason", "Origin", "Action", "Count"}) + + numRows := 0 + + // TODO: sort by reason, origin, action + for reason, origins := range s { + for origin, actions := range origins { + for action, hits := range actions { + t.AppendRow(table.Row{ + reason, + origin, + action, + strconv.Itoa(hits), + }) + + numRows++ + } + } + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapi.go b/cmd/crowdsec-cli/climetrics/statlapi.go new file mode 100644 index 00000000000..9559eacf0f4 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapi.go @@ -0,0 +1,56 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapi map[string]map[string]int + +func (s statLapi) Description() (string, string) { + return "Local API Metrics", + `Monitors the requests made to local API routes.` +} + +func (s statLapi) Process(route, method string, val int) { + if _, ok := s[route]; !ok { + s[route] = make(map[string]int) + } + + s[route][method] += val +} + +func (s statLapi) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Route", "Method", "Hits"}) + + // unfortunately, we can't reuse metricsToTable as the structure is too different :/ + numRows := 0 + + for _, alabel := range maptools.SortedKeys(s) { + astats := s[alabel] + + for _, sl := range maptools.SortedKeys(astats) { + t.AppendRow(table.Row{ + alabel, + sl, + strconv.Itoa(astats[sl]), + }) + + numRows++ + } + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapibouncer.go b/cmd/crowdsec-cli/climetrics/statlapibouncer.go new file mode 100644 index 00000000000..5e5f63a79d3 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapibouncer.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiBouncer map[string]map[string]map[string]int + +func (s statLapiBouncer) Description() (string, string) { + return "Local API Bouncers Metrics", + `Tracks total hits to remediation component related API routes.` +} + +func (s statLapiBouncer) Process(bouncer, route, method string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = make(map[string]map[string]int) + } + + if _, ok := s[bouncer][route]; !ok { + s[bouncer][route] = make(map[string]int) + } + + s[bouncer][route][method] += val +} + +func (s statLapiBouncer) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Bouncer", "Route", "Method", "Hits"}) + + numRows := lapiMetricsToTable(t, s) + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapidecision.go b/cmd/crowdsec-cli/climetrics/statlapidecision.go new file mode 100644 index 00000000000..44f0e8f4b87 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapidecision.go @@ -0,0 +1,64 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiDecision map[string]struct { + NonEmpty int + Empty int +} + +func (s statLapiDecision) Description() (string, string) { + return "Local API Bouncers Decisions", + `Tracks the number of empty/non-empty answers from LAPI to bouncers that are working in "live" mode.` +} + +func (s statLapiDecision) Process(bouncer, fam string, val int) { + if _, ok := s[bouncer]; !ok { + s[bouncer] = struct { + NonEmpty int + Empty int + }{} + } + + x := s[bouncer] + + switch fam { + case "cs_lapi_decisions_ko_total": + x.Empty += val + case "cs_lapi_decisions_ok_total": + x.NonEmpty += val + } + + s[bouncer] = x +} + +func (s statLapiDecision) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Bouncer", "Empty answers", "Non-empty answers"}) + + numRows := 0 + + for bouncer, hits := range s { + t.AppendRow(table.Row{ + bouncer, + strconv.Itoa(hits.Empty), + strconv.Itoa(hits.NonEmpty), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statlapimachine.go b/cmd/crowdsec-cli/climetrics/statlapimachine.go new file mode 100644 index 00000000000..0e6693bea82 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statlapimachine.go @@ -0,0 +1,42 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statLapiMachine map[string]map[string]map[string]int + +func (s statLapiMachine) Description() (string, string) { + return "Local API Machines Metrics", + `Tracks the number of calls to the local API from each registered machine.` +} + +func (s statLapiMachine) Process(machine, route, method string, val int) { + if _, ok := s[machine]; !ok { + s[machine] = make(map[string]map[string]int) + } + + if _, ok := s[machine][route]; !ok { + s[machine][route] = make(map[string]int) + } + + s[machine][route][method] += val +} + +func (s statLapiMachine) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Machine", "Route", "Method", "Hits"}) + + numRows := lapiMetricsToTable(t, s) + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statparser.go b/cmd/crowdsec-cli/climetrics/statparser.go new file mode 100644 index 00000000000..520e68f9adf --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statparser.go @@ -0,0 +1,43 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statParser map[string]map[string]int + +func (s statParser) Description() (string, string) { + return "Parser Metrics", + `Tracks the number of events processed by each parser and indicates success of failure. ` + + `Zero parsed lines means the parser(s) failed. ` + + `Non-zero unparsed lines are fine as crowdsec select relevant lines.` +} + +func (s statParser) Process(parser, metric string, val int) { + if _, ok := s[parser]; !ok { + s[parser] = make(map[string]int) + } + + s[parser][metric] += val +} + +func (s statParser) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Parsers", "Hits", "Parsed", "Unparsed"}) + + keys := []string{"hits", "parsed", "unparsed"} + + if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil { + log.Warningf("while collecting parsers stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statstash.go b/cmd/crowdsec-cli/climetrics/statstash.go new file mode 100644 index 00000000000..2729de931a1 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statstash.go @@ -0,0 +1,59 @@ +package climetrics + +import ( + "io" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statStash map[string]struct { + Type string + Count int +} + +func (s statStash) Description() (string, string) { + return "Parser Stash Metrics", + `Tracks the status of stashes that might be created by various parsers and scenarios.` +} + +func (s statStash) Process(name, mtype string, val int) { + s[name] = struct { + Type string + Count int + }{ + Type: mtype, + Count: val, + } +} + +func (s statStash) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Name", "Type", "Items"}) + + // unfortunately, we can't reuse metricsToTable as the structure is too different :/ + numRows := 0 + + for _, alabel := range maptools.SortedKeys(s) { + astats := s[alabel] + + t.AppendRow(table.Row{ + alabel, + astats.Type, + strconv.Itoa(astats.Count), + }) + + numRows++ + } + + if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/statwhitelist.go b/cmd/crowdsec-cli/climetrics/statwhitelist.go new file mode 100644 index 00000000000..7f533b45b4b --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/statwhitelist.go @@ -0,0 +1,43 @@ +package climetrics + +import ( + "io" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" +) + +type statWhitelist map[string]map[string]map[string]int + +func (s statWhitelist) Description() (string, string) { + return "Whitelist Metrics", + `Tracks the number of events processed and possibly whitelisted by each parser whitelist.` +} + +func (s statWhitelist) Process(whitelist, reason, metric string, val int) { + if _, ok := s[whitelist]; !ok { + s[whitelist] = make(map[string]map[string]int) + } + + if _, ok := s[whitelist][reason]; !ok { + s[whitelist][reason] = make(map[string]int) + } + + s[whitelist][reason][metric] += val +} + +func (s statWhitelist) Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) { + t := cstable.New(out, wantColor).Writer + t.AppendHeader(table.Row{"Whitelist", "Reason", "Hits", "Whitelisted"}) + + if numRows, err := wlMetricsToTable(t, s, noUnit); err != nil { + log.Warningf("while collecting parsers stats: %s", err) + } else if numRows > 0 || showEmpty { + title, _ := s.Description() + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") + io.WriteString(out, "\n") + } +} diff --git a/cmd/crowdsec-cli/climetrics/store.go b/cmd/crowdsec-cli/climetrics/store.go new file mode 100644 index 00000000000..55fab5dbd7f --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/store.go @@ -0,0 +1,271 @@ +package climetrics + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/prom2json" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +type metricSection interface { + Table(out io.Writer, wantColor string, noUnit bool, showEmpty bool) + Description() (string, string) +} + +type metricStore map[string]metricSection + +func NewMetricStore() metricStore { + return metricStore{ + "acquisition": statAcquis{}, + "alerts": statAlert{}, + "bouncers": &statBouncer{}, + "appsec-engine": statAppsecEngine{}, + "appsec-rule": statAppsecRule{}, + "decisions": statDecision{}, + "lapi": statLapi{}, + "lapi-bouncer": statLapiBouncer{}, + "lapi-decisions": statLapiDecision{}, + "lapi-machine": statLapiMachine{}, + "parsers": statParser{}, + "scenarios": statBucket{}, + "stash": statStash{}, + "whitelists": statWhitelist{}, + } +} + +func (ms metricStore) Fetch(ctx context.Context, url string, db *database.Client) error { + if err := ms["bouncers"].(*statBouncer).Fetch(ctx, db); err != nil { + return err + } + + return ms.fetchPrometheusMetrics(url) +} + +func (ms metricStore) fetchPrometheusMetrics(url string) error { + mfChan := make(chan *dto.MetricFamily, 1024) + errChan := make(chan error, 1) + + // Start with the DefaultTransport for sane defaults. + transport := http.DefaultTransport.(*http.Transport).Clone() + // Conservatively disable HTTP keep-alives as this program will only + // ever need a single HTTP request. + transport.DisableKeepAlives = true + // Timeout early if the server doesn't even return the headers. + transport.ResponseHeaderTimeout = time.Minute + go func() { + defer trace.CatchPanic("crowdsec/ShowPrometheus") + + err := prom2json.FetchMetricFamilies(url, mfChan, transport) + if err != nil { + errChan <- fmt.Errorf("while fetching metrics: %w", err) + return + } + errChan <- nil + }() + + result := []*prom2json.Family{} + for mf := range mfChan { + result = append(result, prom2json.NewFamily(mf)) + } + + if err := <-errChan; err != nil { + return err + } + + log.Debugf("Finished reading metrics output, %d entries", len(result)) + ms.processPrometheusMetrics(result) + + return nil +} + +func (ms metricStore) processPrometheusMetrics(result []*prom2json.Family) { + mAcquis := ms["acquisition"].(statAcquis) + mAlert := ms["alerts"].(statAlert) + mAppsecEngine := ms["appsec-engine"].(statAppsecEngine) + mAppsecRule := ms["appsec-rule"].(statAppsecRule) + mDecision := ms["decisions"].(statDecision) + mLapi := ms["lapi"].(statLapi) + mLapiBouncer := ms["lapi-bouncer"].(statLapiBouncer) + mLapiDecision := ms["lapi-decisions"].(statLapiDecision) + mLapiMachine := ms["lapi-machine"].(statLapiMachine) + mParser := ms["parsers"].(statParser) + mBucket := ms["scenarios"].(statBucket) + mStash := ms["stash"].(statStash) + mWhitelist := ms["whitelists"].(statWhitelist) + + for idx, fam := range result { + if !strings.HasPrefix(fam.Name, "cs_") { + continue + } + + log.Tracef("round %d", idx) + + for _, m := range fam.Metrics { + metric, ok := m.(prom2json.Metric) + if !ok { + log.Debugf("failed to convert metric to prom2json.Metric") + continue + } + + name, ok := metric.Labels["name"] + if !ok { + log.Debugf("no name in Metric %v", metric.Labels) + } + + source, ok := metric.Labels["source"] + if !ok { + log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name) + } else { + if srctype, ok := metric.Labels["type"]; ok { + source = srctype + ":" + source + } + } + + value := m.(prom2json.Metric).Value + machine := metric.Labels["machine"] + bouncer := metric.Labels["bouncer"] + + route := metric.Labels["route"] + method := metric.Labels["method"] + + reason := metric.Labels["reason"] + origin := metric.Labels["origin"] + action := metric.Labels["action"] + + appsecEngine := metric.Labels["appsec_engine"] + appsecRule := metric.Labels["rule_name"] + + mtype := metric.Labels["type"] + + fval, err := strconv.ParseFloat(value, 32) + if err != nil { + log.Errorf("Unexpected int value %s : %s", value, err) + } + + ival := int(fval) + + switch fam.Name { + // + // buckets + // + case "cs_bucket_created_total": + mBucket.Process(name, "instantiation", ival) + case "cs_buckets": + mBucket.Process(name, "curr_count", ival) + case "cs_bucket_overflowed_total": + mBucket.Process(name, "overflow", ival) + case "cs_bucket_poured_total": + mBucket.Process(name, "pour", ival) + mAcquis.Process(source, "pour", ival) + case "cs_bucket_underflowed_total": + mBucket.Process(name, "underflow", ival) + // + // parsers + // + case "cs_parser_hits_total": + mAcquis.Process(source, "reads", ival) + case "cs_parser_hits_ok_total": + mAcquis.Process(source, "parsed", ival) + case "cs_parser_hits_ko_total": + mAcquis.Process(source, "unparsed", ival) + case "cs_node_hits_total": + mParser.Process(name, "hits", ival) + case "cs_node_hits_ok_total": + mParser.Process(name, "parsed", ival) + case "cs_node_hits_ko_total": + mParser.Process(name, "unparsed", ival) + // + // whitelists + // + case "cs_node_wl_hits_total": + mWhitelist.Process(name, reason, "hits", ival) + case "cs_node_wl_hits_ok_total": + mWhitelist.Process(name, reason, "whitelisted", ival) + // track as well whitelisted lines at acquis level + mAcquis.Process(source, "whitelisted", ival) + // + // lapi + // + case "cs_lapi_route_requests_total": + mLapi.Process(route, method, ival) + case "cs_lapi_machine_requests_total": + mLapiMachine.Process(machine, route, method, ival) + case "cs_lapi_bouncer_requests_total": + mLapiBouncer.Process(bouncer, route, method, ival) + case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total": + mLapiDecision.Process(bouncer, fam.Name, ival) + // + // decisions + // + case "cs_active_decisions": + mDecision.Process(reason, origin, action, ival) + case "cs_alerts": + mAlert.Process(reason, ival) + // + // stash + // + case "cs_cache_size": + mStash.Process(name, mtype, ival) + // + // appsec + // + case "cs_appsec_reqs_total": + mAppsecEngine.Process(appsecEngine, "processed", ival) + case "cs_appsec_block_total": + mAppsecEngine.Process(appsecEngine, "blocked", ival) + case "cs_appsec_rule_hits": + mAppsecRule.Process(appsecEngine, appsecRule, "triggered", ival) + default: + log.Debugf("unknown: %+v", fam.Name) + continue + } + } + } +} + +func (ms metricStore) Format(out io.Writer, wantColor string, sections []string, outputFormat string, noUnit bool) error { + // copy only the sections we want + want := map[string]metricSection{} + + // if explicitly asking for sections, we want to show empty tables + showEmpty := len(sections) > 0 + + // if no sections are specified, we want all of them + if len(sections) == 0 { + sections = maptools.SortedKeys(ms) + } + + for _, section := range sections { + want[section] = ms[section] + } + + switch outputFormat { + case "human": + for _, section := range maptools.SortedKeys(want) { + want[section].Table(out, wantColor, noUnit, showEmpty) + } + case "json": + x, err := json.MarshalIndent(want, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize metrics: %w", err) + } + out.Write(x) + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} diff --git a/cmd/crowdsec-cli/climetrics/table.go b/cmd/crowdsec-cli/climetrics/table.go new file mode 100644 index 00000000000..af13edce2f5 --- /dev/null +++ b/cmd/crowdsec-cli/climetrics/table.go @@ -0,0 +1,122 @@ +package climetrics + +import ( + "errors" + "sort" + "strconv" + + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/maptools" +) + +// ErrNilTable means a nil pointer was passed instead of a table instance. This is a programming error. +var ErrNilTable = errors.New("nil table") + +func lapiMetricsToTable(t table.Writer, stats map[string]map[string]map[string]int) int { + // stats: machine -> route -> method -> count + // sort keys to keep consistent order when printing + machineKeys := []string{} + for k := range stats { + machineKeys = append(machineKeys, k) + } + + sort.Strings(machineKeys) + + numRows := 0 + + for _, machine := range machineKeys { + // oneRow: route -> method -> count + machineRow := stats[machine] + for routeName, route := range machineRow { + for methodName, count := range route { + row := table.Row{ + machine, + routeName, + methodName, + } + if count != 0 { + row = append(row, strconv.Itoa(count)) + } else { + row = append(row, "-") + } + + t.AppendRow(row) + + numRows++ + } + } + } + + return numRows +} + +func wlMetricsToTable(t table.Writer, stats map[string]map[string]map[string]int, noUnit bool) (int, error) { + if t == nil { + return 0, ErrNilTable + } + + numRows := 0 + + for _, name := range maptools.SortedKeys(stats) { + for _, reason := range maptools.SortedKeys(stats[name]) { + row := table.Row{ + name, + reason, + "-", + "-", + } + + for _, action := range maptools.SortedKeys(stats[name][reason]) { + value := stats[name][reason][action] + + switch action { + case "whitelisted": + row[3] = strconv.Itoa(value) + case "hits": + row[2] = strconv.Itoa(value) + default: + log.Debugf("unexpected counter '%s' for whitelists = %d", action, value) + } + } + + t.AppendRow(row) + + numRows++ + } + } + + return numRows, nil +} + +func metricsToTable(t table.Writer, stats map[string]map[string]int, keys []string, noUnit bool) (int, error) { + if t == nil { + return 0, ErrNilTable + } + + numRows := 0 + + for _, alabel := range maptools.SortedKeys(stats) { + astats, ok := stats[alabel] + if !ok { + continue + } + + row := table.Row{alabel} + + for _, sl := range keys { + if v, ok := astats[sl]; ok && v != 0 { + row = append(row, formatNumber(int64(v), !noUnit)) + } else { + row = append(row, "-") + } + } + + t.AppendRow(row) + + numRows++ + } + + return numRows, nil +} diff --git a/cmd/crowdsec-cli/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go similarity index 65% rename from cmd/crowdsec-cli/notifications.go rename to cmd/crowdsec-cli/clinotifications/notifications.go index da436420d12..5489faa37c8 100644 --- a/cmd/crowdsec-cli/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -1,14 +1,16 @@ -package main +package clinotifications import ( "context" "encoding/csv" "encoding/json" + "errors" "fmt" "io/fs" "net/url" "os" "path/filepath" + "slices" "strconv" "strings" "time" @@ -21,16 +23,14 @@ import ( "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csprofiles" - "github.com/crowdsecurity/crowdsec/pkg/types" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type NotificationsCfg struct { @@ -39,13 +39,19 @@ type NotificationsCfg struct { ids []uint } -type cliNotifications struct{} +type configGetter func() *csconfig.Config + +type cliNotifications struct { + cfg configGetter +} -func NewCLINotifications() *cliNotifications { - return &cliNotifications{} +func New(cfg configGetter) *cliNotifications { + return &cliNotifications{ + cfg: cfg, + } } -func (cli cliNotifications) NewCommand() *cobra.Command { +func (cli *cliNotifications) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "notifications [action]", Short: "Helper for notification plugin configuration", @@ -53,112 +59,123 @@ func (cli cliNotifications) NewCommand() *cobra.Command { Args: cobra.MinimumNArgs(1), Aliases: []string{"notifications", "notification"}, DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } - if err := csConfig.LoadAPIClient(); err != nil { + if err := cfg.LoadAPIClient(); err != nil { return fmt.Errorf("loading api client: %w", err) } - if err := require.Notifications(csConfig); err != nil { - return err - } - return nil + return require.Notifications(cfg) }, } - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewInspectCmd()) - cmd.AddCommand(cli.NewReinjectCmd()) - cmd.AddCommand(cli.NewTestCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newReinjectCmd()) + cmd.AddCommand(cli.newTestCmd()) return cmd } -func getPluginConfigs() (map[string]csplugin.PluginConfig, error) { +func (cli *cliNotifications) getPluginConfigs() (map[string]csplugin.PluginConfig, error) { + cfg := cli.cfg() pcfgs := map[string]csplugin.PluginConfig{} wf := func(path string, info fs.FileInfo, err error) error { if info == nil { return fmt.Errorf("error while traversing directory %s: %w", path, err) } - name := filepath.Join(csConfig.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice + + name := filepath.Join(cfg.ConfigPaths.NotificationDir, info.Name()) // Avoid calling info.Name() twice if (strings.HasSuffix(name, "yaml") || strings.HasSuffix(name, "yml")) && !(info.IsDir()) { ts, err := csplugin.ParsePluginConfigFile(name) if err != nil { return fmt.Errorf("loading notifification plugin configuration with %s: %w", name, err) } + for _, t := range ts { csplugin.SetRequiredFields(&t) pcfgs[t.Name] = t } } + return nil } - if err := filepath.Walk(csConfig.ConfigPaths.NotificationDir, wf); err != nil { + if err := filepath.Walk(cfg.ConfigPaths.NotificationDir, wf); err != nil { return nil, fmt.Errorf("while loading notifification plugin configuration: %w", err) } + return pcfgs, nil } -func getProfilesConfigs() (map[string]NotificationsCfg, error) { - // A bit of a tricky stuf now: reconcile profiles and notification plugins - pcfgs, err := getPluginConfigs() +func (cli *cliNotifications) getProfilesConfigs() (map[string]NotificationsCfg, error) { + cfg := cli.cfg() + // A bit of a tricky stuff now: reconcile profiles and notification plugins + pcfgs, err := cli.getPluginConfigs() if err != nil { return nil, err } + ncfgs := map[string]NotificationsCfg{} for _, pc := range pcfgs { ncfgs[pc.Name] = NotificationsCfg{ Config: pc, } } - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) + + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { return nil, fmt.Errorf("while extracting profiles from configuration: %w", err) } + for profileID, profile := range profiles { for _, notif := range profile.Cfg.Notifications { pc, ok := pcfgs[notif] if !ok { return nil, fmt.Errorf("notification plugin '%s' does not exist", notif) } + tmp, ok := ncfgs[pc.Name] if !ok { return nil, fmt.Errorf("notification plugin '%s' does not exist", pc.Name) } + tmp.Profiles = append(tmp.Profiles, profile.Cfg) tmp.ids = append(tmp.ids, uint(profileID)) ncfgs[pc.Name] = tmp } } + return ncfgs, nil } -func (cli cliNotifications) NewListCmd() *cobra.Command { +func (cli *cliNotifications) newListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list", - Short: "list active notifications plugins", - Long: `list active notifications plugins`, + Short: "list notifications plugins", + Long: `list notifications plugins and their status (active or not)`, Example: `cscli notifications list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, arg []string) error { - ncfgs, err := getProfilesConfigs() + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - if csConfig.Cscli.Output == "human" { - notificationListTable(color.Output, ncfgs) - } else if csConfig.Cscli.Output == "json" { + if cfg.Cscli.Output == "human" { + notificationListTable(color.Output, cfg.Cscli.Color, ncfgs) + } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(ncfgs, "", " ") if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) + return fmt.Errorf("failed to serialize notification configuration: %w", err) } fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "raw" { + } else if cfg.Cscli.Output == "raw" { csvwriter := csv.NewWriter(os.Stdout) err := csvwriter.Write([]string{"Name", "Type", "Profile name"}) if err != nil { @@ -176,6 +193,7 @@ func (cli cliNotifications) NewListCmd() *cobra.Command { } csvwriter.Flush() } + return nil }, } @@ -183,44 +201,41 @@ func (cli cliNotifications) NewListCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewInspectCmd() *cobra.Command { +func (cli *cliNotifications) newInspectCmd() *cobra.Command { cmd := &cobra.Command{ Use: "inspect", - Short: "Inspect active notifications plugin configuration", - Long: `Inspect active notifications plugin and show configuration`, + Short: "Inspect notifications plugin", + Long: `Inspect notifications plugin and show configuration`, Example: `cscli notifications inspect `, Args: cobra.ExactArgs(1), + ValidArgsFunction: cli.notificationConfigFilter, DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, args []string) error { - if args[0] == "" { - return fmt.Errorf("please provide a plugin name to inspect") - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - ncfgs, err := getProfilesConfigs() + RunE: func(_ *cobra.Command, args []string) error { + cfg := cli.cfg() + ncfgs, err := cli.getProfilesConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - cfg, ok := ncfgs[args[0]] + ncfg, ok := ncfgs[args[0]] if !ok { return fmt.Errorf("plugin '%s' does not exist or is not active", args[0]) } - if csConfig.Cscli.Output == "human" || csConfig.Cscli.Output == "raw" { - fmt.Printf(" - %15s: %15s\n", "Type", cfg.Config.Type) - fmt.Printf(" - %15s: %15s\n", "Name", cfg.Config.Name) - fmt.Printf(" - %15s: %15s\n", "Timeout", cfg.Config.TimeOut) - fmt.Printf(" - %15s: %15s\n", "Format", cfg.Config.Format) - for k, v := range cfg.Config.Config { + if cfg.Cscli.Output == "human" || cfg.Cscli.Output == "raw" { + fmt.Printf(" - %15s: %15s\n", "Type", ncfg.Config.Type) + fmt.Printf(" - %15s: %15s\n", "Name", ncfg.Config.Name) + fmt.Printf(" - %15s: %15s\n", "Timeout", ncfg.Config.TimeOut) + fmt.Printf(" - %15s: %15s\n", "Format", ncfg.Config.Format) + for k, v := range ncfg.Config.Config { fmt.Printf(" - %15s: %15v\n", k, v) } - } else if csConfig.Cscli.Output == "json" { + } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(cfg, "", " ") if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) + return fmt.Errorf("failed to serialize notification configuration: %w", err) } fmt.Printf("%s", string(x)) } + return nil }, } @@ -228,38 +243,59 @@ func (cli cliNotifications) NewInspectCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewTestCmd() *cobra.Command { +func (cli *cliNotifications) notificationConfigFilter(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + ncfgs, err := cli.getProfilesConfigs() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + var ret []string + + for k := range ncfgs { + if strings.Contains(k, toComplete) && !slices.Contains(args, k) { + ret = append(ret, k) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli cliNotifications) newTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb alertOverride string ) + cmd := &cobra.Command{ Use: "test [plugin name]", Short: "send a generic test alert to notification plugin", - Long: `send a generic test alert to a notification plugin to test configuration even if is not active`, + Long: `send a generic test alert to a notification plugin even if it is not active in profiles`, Example: `cscli notifications test [plugin_name]`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, + ValidArgsFunction: cli.notificationConfigFilter, PreRunE: func(cmd *cobra.Command, args []string) error { - pconfigs, err := getPluginConfigs() + ctx := cmd.Context() + cfg := cli.cfg() + pconfigs, err := cli.getPluginConfigs() if err != nil { return fmt.Errorf("can't build profiles configuration: %w", err) } - cfg, ok := pconfigs[args[0]] + pcfg, ok := pconfigs[args[0]] if !ok { return fmt.Errorf("plugin name: '%s' does not exist", args[0]) } - //Create a single profile with plugin name as notification name - return pluginBroker.Init(csConfig.PluginConfig, []*csconfig.ProfileCfg{ + // Create a single profile with plugin name as notification name + return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{ { Notifications: []string{ - cfg.Name, + pcfg.Name, }, }, - }, csConfig.ConfigPaths) + }, cfg.ConfigPaths) }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { pluginTomb.Go(func() error { pluginBroker.Run(&pluginTomb) return nil @@ -296,15 +332,18 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { CreatedAt: time.Now().UTC().Format(time.RFC3339), } if err := yaml.Unmarshal([]byte(alertOverride), alert); err != nil { - return fmt.Errorf("failed to unmarshal alert override: %w", err) + return fmt.Errorf("failed to parse alert override: %w", err) } + pluginBroker.PluginChannel <- csplugin.ProfileAlert{ ProfileID: uint(0), Alert: alert, } - //time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(fmt.Errorf("terminating")) + + // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent + pluginTomb.Kill(errors.New("terminating")) pluginTomb.Wait() + return nil }, } @@ -313,9 +352,11 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { return cmd } -func (cli cliNotifications) NewReinjectCmd() *cobra.Command { - var alertOverride string - var alert *models.Alert +func (cli *cliNotifications) newReinjectCmd() *cobra.Command { + var ( + alertOverride string + alert *models.Alert + ) cmd := &cobra.Command{ Use: "reinject", @@ -330,23 +371,29 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not DisableAutoGenTag: true, PreRunE: func(cmd *cobra.Command, args []string) error { var err error - alert, err = FetchAlertFromArgString(args[0]) + alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0]) if err != nil { return err } + return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb ) + + ctx := cmd.Context() + cfg := cli.cfg() + if alertOverride != "" { if err := json.Unmarshal([]byte(alertOverride), alert); err != nil { - return fmt.Errorf("can't unmarshal data in the alert flag: %w", err) + return fmt.Errorf("can't parse data in the alert flag: %w", err) } } - err := pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths) + + err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) if err != nil { return fmt.Errorf("can't initialize plugins: %w", err) } @@ -356,7 +403,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return nil }) - profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles) + profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { return fmt.Errorf("cannot extract profiles from configuration: %w", err) } @@ -382,17 +429,18 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not default: time.Sleep(50 * time.Millisecond) log.Info("sleeping\n") - } } + if profile.Cfg.OnSuccess == "break" { log.Infof("The profile %s contains a 'on_success: break' so bailing out", profile.Cfg.Name) break } } - //time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(fmt.Errorf("terminating")) + // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent + pluginTomb.Kill(errors.New("terminating")) pluginTomb.Wait() + return nil }, } @@ -401,28 +449,33 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return cmd } -func FetchAlertFromArgString(toParse string) (*models.Alert, error) { +func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) { + cfg := cli.cfg() + id, err := strconv.Atoi(toParse) if err != nil { return nil, fmt.Errorf("bad alert id %s", toParse) } - apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) + + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) if err != nil { return nil, fmt.Errorf("error parsing the URL of the API: %w", err) } + client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: strfmt.Password(csConfig.API.Client.Credentials.Password), - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), URL: apiURL, VersionPrefix: "v1", }) if err != nil { return nil, fmt.Errorf("error creating the client for the API: %w", err) } - alert, _, err := client.Alerts.GetByID(context.Background(), id) + + alert, _, err := client.Alerts.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) } + return alert, nil } diff --git a/cmd/crowdsec-cli/notifications_table.go b/cmd/crowdsec-cli/clinotifications/notifications_table.go similarity index 52% rename from cmd/crowdsec-cli/notifications_table.go rename to cmd/crowdsec-cli/clinotifications/notifications_table.go index e0f61d9cebe..0b6a3f58efc 100644 --- a/cmd/crowdsec-cli/notifications_table.go +++ b/cmd/crowdsec-cli/clinotifications/notifications_table.go @@ -1,37 +1,46 @@ -package main +package clinotifications import ( "io" "sort" "strings" - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" + "github.com/jedib0t/go-pretty/v6/text" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/emoji" ) -func notificationListTable(out io.Writer, ncfgs map[string]NotificationsCfg) { - t := newLightTable(out) +func notificationListTable(out io.Writer, wantColor string, ncfgs map[string]NotificationsCfg) { + t := cstable.NewLight(out, wantColor) t.SetHeaders("Active", "Name", "Type", "Profile name") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) + t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft, text.AlignLeft) + t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft, text.AlignLeft) + keys := make([]string, 0, len(ncfgs)) for k := range ncfgs { keys = append(keys, k) } + sort.Slice(keys, func(i, j int) bool { return len(ncfgs[keys[i]].Profiles) > len(ncfgs[keys[j]].Profiles) }) + for _, k := range keys { b := ncfgs[k] profilesList := []string{} + for _, p := range b.Profiles { profilesList = append(profilesList, p.Name) } - active := emoji.CheckMark.String() + + active := emoji.CheckMark if len(profilesList) == 0 { - active = emoji.Prohibited.String() + active = emoji.Prohibited } + t.AddRow(active, b.Config.Name, b.Config.Type, strings.Join(profilesList, ", ")) } + t.Render() } diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go new file mode 100644 index 00000000000..461215c3a39 --- /dev/null +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -0,0 +1,174 @@ +package clipapi + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiserver" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +type configGetter = func() *csconfig.Config + +type cliPapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliPapi { + return &cliPapi{ + cfg: cfg, + } +} + +func (cli *cliPapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "papi [action]", + Short: "Manage interaction with Polling API (PAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := require.CAPI(cfg); err != nil { + return err + } + + return require.PAPI(cfg) + }, + } + + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newSyncCmd()) + + return cmd +} + +func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + perms, err := papi.GetPermissions(ctx) + if err != nil { + return fmt.Errorf("unable to get PAPI permissions: %w", err) + } + + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) + if err != nil { + lastTimestampStr = ptr.Of("never") + } + + // both can and did happen + if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" { + lastTimestampStr = ptr.Of("never") + } + + fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n") + fmt.Fprintf(out, "Console plan: %s\n", perms.Plan) + fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr) + fmt.Fprint(out, "PAPI subscriptions:\n") + + for _, sub := range perms.Categories { + fmt.Fprintf(out, " - %s\n", sub) + } + + return nil +} + +func (cli *cliPapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Get status of the Polling API", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.Status(ctx, color.Output, db) + }, + } + + return cmd +} + +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + t := tomb.Tomb{} + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + t.Go(func() error { return apic.Push(ctx) }) + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + t.Go(papi.SyncDecisions) + + err = papi.PullOnce(time.Time{}, true) + if err != nil { + return fmt.Errorf("unable to sync decisions: %w", err) + } + + log.Infof("Sending acknowledgements to CAPI") + + apic.Shutdown() + papi.Shutdown() + t.Wait() + time.Sleep(5 * time.Second) // FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done + + return nil +} + +func (cli *cliPapi) newSyncCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "sync", + Short: "Sync with the Polling API, pulling all non-expired orders for the instance", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.sync(ctx, color.Output, db) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clisetup/setup.go b/cmd/crowdsec-cli/clisetup/setup.go new file mode 100644 index 00000000000..269cdfb78e9 --- /dev/null +++ b/cmd/crowdsec-cli/clisetup/setup.go @@ -0,0 +1,307 @@ +package clisetup + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + + goccyyaml "github.com/goccy/go-yaml" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/setup" +) + +type configGetter func() *csconfig.Config + +type cliSetup struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSetup { + return &cliSetup{ + cfg: cfg, + } +} + +func (cli *cliSetup) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "setup", + Short: "Tools to configure crowdsec", + Long: "Manage hub configuration and service detection", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newDetectCmd()) + cmd.AddCommand(cli.newInstallHubCmd()) + cmd.AddCommand(cli.newDataSourcesCmd()) + cmd.AddCommand(cli.newValidateCmd()) + + return cmd +} + +type detectFlags struct { + detectConfigFile string + listSupportedServices bool + forcedUnits []string + forcedProcesses []string + forcedOSFamily string + forcedOSID string + forcedOSVersion string + skipServices []string + snubSystemd bool + outYaml bool +} + +func (f *detectFlags) bind(cmd *cobra.Command) { + defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") + + flags := cmd.Flags() + flags.StringVar(&f.detectConfigFile, "detect-config", defaultServiceDetect, "path to service detection configuration") + flags.BoolVar(&f.listSupportedServices, "list-supported-services", false, "do not detect; only print supported services") + flags.StringSliceVar(&f.forcedUnits, "force-unit", nil, "force detection of a systemd unit (can be repeated)") + flags.StringSliceVar(&f.forcedProcesses, "force-process", nil, "force detection of a running process (can be repeated)") + flags.StringSliceVar(&f.skipServices, "skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") + flags.StringVar(&f.forcedOSFamily, "force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") + flags.StringVar(&f.forcedOSID, "force-os-id", "", "override OS.ID=[debian | ubuntu | , redhat...]") + flags.StringVar(&f.forcedOSVersion, "force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") + flags.BoolVar(&f.snubSystemd, "snub-systemd", false, "don't use systemd, even if available") + flags.BoolVar(&f.outYaml, "yaml", false, "output yaml, not json") +} + +func (cli *cliSetup) newDetectCmd() *cobra.Command { + f := detectFlags{} + + cmd := &cobra.Command{ + Use: "detect", + Short: "detect running services, generate a setup file", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.detect(f) + }, + } + + f.bind(cmd) + + return cmd +} + +func (cli *cliSetup) newInstallHubCmd() *cobra.Command { + var dryRun bool + + cmd := &cobra.Command{ + Use: "install-hub [setup_file] [flags]", + Short: "install items from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.install(cmd.Context(), dryRun, args[0]) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&dryRun, "dry-run", false, "don't install anything; print out what would have been") + + return cmd +} + +func (cli *cliSetup) newDataSourcesCmd() *cobra.Command { + var toDir string + + cmd := &cobra.Command{ + Use: "datasources [setup_file] [flags]", + Short: "generate datasource (acquisition) configuration from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.dataSources(args[0], toDir) + }, + } + + flags := cmd.Flags() + flags.StringVar(&toDir, "to-dir", "", "write the configuration to a directory, in multiple files") + + return cmd +} + +func (cli *cliSetup) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate [setup_file]", + Short: "validate a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(args[0]) + }, + } + + return cmd +} + +func (cli *cliSetup) detect(f detectFlags) error { + var ( + detectReader *os.File + err error + ) + + switch f.detectConfigFile { + case "-": + log.Tracef("Reading detection rules from stdin") + + detectReader = os.Stdin + default: + log.Tracef("Reading detection rules: %s", f.detectConfigFile) + + detectReader, err = os.Open(f.detectConfigFile) + if err != nil { + return err + } + } + + if !f.snubSystemd { + _, err = exec.LookPath("systemctl") + if err != nil { + log.Debug("systemctl not available: snubbing systemd") + + f.snubSystemd = true + } + } + + if f.forcedOSFamily == "" && f.forcedOSID != "" { + log.Debug("force-os-id is set: force-os-family defaults to 'linux'") + + f.forcedOSFamily = "linux" + } + + if f.listSupportedServices { + supported, err := setup.ListSupported(detectReader) + if err != nil { + return err + } + + for _, svc := range supported { + fmt.Println(svc) + } + + return nil + } + + opts := setup.DetectOptions{ + ForcedUnits: f.forcedUnits, + ForcedProcesses: f.forcedProcesses, + ForcedOS: setup.ExprOS{ + Family: f.forcedOSFamily, + ID: f.forcedOSID, + RawVersion: f.forcedOSVersion, + }, + SkipServices: f.skipServices, + SnubSystemd: f.snubSystemd, + } + + hubSetup, err := setup.Detect(detectReader, opts) + if err != nil { + return fmt.Errorf("detecting services: %w", err) + } + + setup, err := setupAsString(hubSetup, f.outYaml) + if err != nil { + return err + } + + fmt.Println(setup) + + return nil +} + +func setupAsString(cs setup.Setup, outYaml bool) (string, error) { + var ( + ret []byte + err error + ) + + wrap := func(err error) error { + return fmt.Errorf("while serializing setup: %w", err) + } + + indentLevel := 2 + buf := &bytes.Buffer{} + enc := yaml.NewEncoder(buf) + enc.SetIndent(indentLevel) + + if err = enc.Encode(cs); err != nil { + return "", wrap(err) + } + + if err = enc.Close(); err != nil { + return "", wrap(err) + } + + ret = buf.Bytes() + + if !outYaml { + // take a general approach to output json, so we avoid the + // double tags in the structures and can use go-yaml features + // missing from the json package + ret, err = goccyyaml.YAMLToJSON(ret) + if err != nil { + return "", wrap(err) + } + } + + return string(ret), nil +} + +func (cli *cliSetup) dataSources(fromFile string, toDir string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading setup file: %w", err) + } + + output, err := setup.DataSources(input, toDir) + if err != nil { + return err + } + + if toDir == "" { + fmt.Println(output) + } + + return nil +} + +func (cli *cliSetup) install(ctx context.Context, dryRun bool, fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading file %s: %w", fromFile, err) + } + + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + return setup.InstallHubItems(ctx, hub, input, dryRun) +} + +func (cli *cliSetup) validate(fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading stdin: %w", err) + } + + if err = setup.Validate(input); err != nil { + fmt.Printf("%v\n", err) + return errors.New("invalid setup file") + } + + return nil +} diff --git a/cmd/crowdsec-cli/clisimulation/simulation.go b/cmd/crowdsec-cli/clisimulation/simulation.go new file mode 100644 index 00000000000..8136aa213c3 --- /dev/null +++ b/cmd/crowdsec-cli/clisimulation/simulation.go @@ -0,0 +1,286 @@ +package clisimulation + +import ( + "errors" + "fmt" + "os" + "slices" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +type configGetter func() *csconfig.Config + +type cliSimulation struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSimulation { + return &cliSimulation{ + cfg: cfg, + } +} + +func (cli *cliSimulation) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "simulation [command]", + Short: "Manage simulation status of scenarios", + Example: `cscli simulation status +cscli simulation enable crowdsecurity/ssh-bf +cscli simulation disable crowdsecurity/ssh-bf`, + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadSimulation(); err != nil { + return err + } + if cli.cfg().Cscli.SimulationConfig == nil { + return errors.New("no simulation configured") + } + + return nil + }, + PersistentPostRun: func(cmd *cobra.Command, _ []string) { + if cmd.Name() != "status" { + log.Info(reload.Message) + } + }, + } + cmd.Flags().SortFlags = false + cmd.PersistentFlags().SortFlags = false + + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) + + return cmd +} + +func (cli *cliSimulation) newEnableCmd() *cobra.Command { + var forceGlobalSimulation bool + + cmd := &cobra.Command{ + Use: "enable [scenario] [-global]", + Short: "Enable the simulation, globally or on specified scenarios", + Example: `cscli simulation enable`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + if len(args) > 0 { + for _, scenario := range args { + item := hub.GetItem(cwhub.SCENARIOS, scenario) + if item == nil { + log.Errorf("'%s' doesn't exist or is not a scenario", scenario) + continue + } + if !item.State.Installed { + log.Warningf("'%s' isn't enabled", scenario) + } + isExcluded := slices.Contains(cli.cfg().Cscli.SimulationConfig.Exclusions, scenario) + if *cli.cfg().Cscli.SimulationConfig.Simulation && !isExcluded { + log.Warning("global simulation is already enabled") + continue + } + if !*cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + log.Warningf("simulation for '%s' already enabled", scenario) + continue + } + if *cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + cli.removeFromExclusion(scenario) + log.Printf("simulation enabled for '%s'", scenario) + continue + } + cli.addToExclusion(scenario) + log.Printf("simulation mode for '%s' enabled", scenario) + } + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("simulation enable: %w", err) + } + } else if forceGlobalSimulation { + if err := cli.enableGlobalSimulation(); err != nil { + return fmt.Errorf("unable to enable global simulation mode: %w", err) + } + } else { + _ = cmd.Help() + } + + return nil + }, + } + cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Enable global simulation (reverse mode)") + + return cmd +} + +func (cli *cliSimulation) newDisableCmd() *cobra.Command { + var forceGlobalSimulation bool + + cmd := &cobra.Command{ + Use: "disable [scenario]", + Short: "Disable the simulation mode. Disable only specified scenarios", + Example: `cscli simulation disable`, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + for _, scenario := range args { + isExcluded := slices.Contains(cli.cfg().Cscli.SimulationConfig.Exclusions, scenario) + if !*cli.cfg().Cscli.SimulationConfig.Simulation && !isExcluded { + log.Warningf("%s isn't in simulation mode", scenario) + continue + } + if !*cli.cfg().Cscli.SimulationConfig.Simulation && isExcluded { + cli.removeFromExclusion(scenario) + log.Printf("simulation mode for '%s' disabled", scenario) + continue + } + if isExcluded { + log.Warningf("simulation mode is enabled but is already disable for '%s'", scenario) + continue + } + cli.addToExclusion(scenario) + log.Printf("simulation mode for '%s' disabled", scenario) + } + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("simulation disable: %w", err) + } + } else if forceGlobalSimulation { + if err := cli.disableGlobalSimulation(); err != nil { + return fmt.Errorf("unable to disable global simulation mode: %w", err) + } + } else { + _ = cmd.Help() + } + + return nil + }, + } + cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Disable global simulation (reverse mode)") + + return cmd +} + +func (cli *cliSimulation) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Show simulation mode status", + Example: `cscli simulation status`, + DisableAutoGenTag: true, + Run: func(_ *cobra.Command, _ []string) { + cli.status() + }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + }, + } + + return cmd +} + +func (cli *cliSimulation) addToExclusion(name string) { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Exclusions = append(cfg.Cscli.SimulationConfig.Exclusions, name) +} + +func (cli *cliSimulation) removeFromExclusion(name string) { + cfg := cli.cfg() + index := slices.Index(cfg.Cscli.SimulationConfig.Exclusions, name) + + // Remove element from the slice + cfg.Cscli.SimulationConfig.Exclusions[index] = cfg.Cscli.SimulationConfig.Exclusions[len(cfg.Cscli.SimulationConfig.Exclusions)-1] + cfg.Cscli.SimulationConfig.Exclusions[len(cfg.Cscli.SimulationConfig.Exclusions)-1] = "" + cfg.Cscli.SimulationConfig.Exclusions = cfg.Cscli.SimulationConfig.Exclusions[:len(cfg.Cscli.SimulationConfig.Exclusions)-1] +} + +func (cli *cliSimulation) enableGlobalSimulation() error { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Simulation = new(bool) + *cfg.Cscli.SimulationConfig.Simulation = true + cfg.Cscli.SimulationConfig.Exclusions = []string{} + + if err := cli.dumpSimulationFile(); err != nil { + return fmt.Errorf("unable to dump simulation file: %w", err) + } + + log.Printf("global simulation: enabled") + + return nil +} + +func (cli *cliSimulation) dumpSimulationFile() error { + cfg := cli.cfg() + + newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) + if err != nil { + return fmt.Errorf("unable to serialize simulation configuration: %w", err) + } + + err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) + if err != nil { + return fmt.Errorf("write simulation config in '%s' failed: %w", cfg.ConfigPaths.SimulationFilePath, err) + } + + log.Debugf("updated simulation file %s", cfg.ConfigPaths.SimulationFilePath) + + return nil +} + +func (cli *cliSimulation) disableGlobalSimulation() error { + cfg := cli.cfg() + cfg.Cscli.SimulationConfig.Simulation = new(bool) + *cfg.Cscli.SimulationConfig.Simulation = false + + cfg.Cscli.SimulationConfig.Exclusions = []string{} + + newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) + if err != nil { + return fmt.Errorf("unable to serialize new simulation configuration: %w", err) + } + + err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) + if err != nil { + return fmt.Errorf("unable to write new simulation config in '%s': %w", cfg.ConfigPaths.SimulationFilePath, err) + } + + log.Printf("global simulation: disabled") + + return nil +} + +func (cli *cliSimulation) status() { + cfg := cli.cfg() + if cfg.Cscli.SimulationConfig == nil { + log.Printf("global simulation: disabled (configuration file is missing)") + return + } + + if *cfg.Cscli.SimulationConfig.Simulation { + log.Println("global simulation: enabled") + + if len(cfg.Cscli.SimulationConfig.Exclusions) > 0 { + log.Println("Scenarios not in simulation mode :") + + for _, scenario := range cfg.Cscli.SimulationConfig.Exclusions { + log.Printf(" - %s", scenario) + } + } + } else { + log.Println("global simulation: disabled") + + if len(cfg.Cscli.SimulationConfig.Exclusions) > 0 { + log.Println("Scenarios in simulation mode :") + + for _, scenario := range cfg.Cscli.SimulationConfig.Exclusions { + log.Printf(" - %s", scenario) + } + } + } +} diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go new file mode 100644 index 00000000000..4474f5c8f11 --- /dev/null +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -0,0 +1,642 @@ +package clisupport + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/blackfireio/osinfo" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/fflag" +) + +const ( + SUPPORT_METRICS_DIR = "metrics/" + SUPPORT_VERSION_PATH = "version.txt" + SUPPORT_FEATURES_PATH = "features.txt" + SUPPORT_OS_INFO_PATH = "osinfo.txt" + SUPPORT_HUB = "hub.txt" + SUPPORT_BOUNCERS_PATH = "lapi/bouncers.txt" + SUPPORT_AGENTS_PATH = "lapi/agents.txt" + SUPPORT_CROWDSEC_CONFIG_PATH = "config/crowdsec.yaml" + SUPPORT_LAPI_STATUS_PATH = "lapi_status.txt" + SUPPORT_CAPI_STATUS_PATH = "capi_status.txt" + SUPPORT_PAPI_STATUS_PATH = "papi_status.txt" + SUPPORT_ACQUISITION_DIR = "config/acquis/" + SUPPORT_CROWDSEC_PROFILE_PATH = "config/profiles.yaml" + SUPPORT_CRASH_DIR = "crash/" + SUPPORT_LOG_DIR = "log/" + SUPPORT_PPROF_DIR = "pprof/" +) + +// StringHook collects log entries in a string +type StringHook struct { + LogBuilder strings.Builder + LogLevels []log.Level +} + +func (hook *StringHook) Levels() []log.Level { + return hook.LogLevels +} + +func (hook *StringHook) Fire(entry *log.Entry) error { + logEntry, err := entry.String() + if err != nil { + return err + } + + hook.LogBuilder.WriteString(logEntry) + + return nil +} + +// from https://github.com/acarl005/stripansi +var reStripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") + +func stripAnsiString(str string) string { + // the byte version doesn't strip correctly + return reStripAnsi.ReplaceAllString(str, "") +} + +func (cli *cliSupport) dumpMetrics(ctx context.Context, db *database.Client, zw *zip.Writer) error { + log.Info("Collecting prometheus metrics") + + cfg := cli.cfg() + + if cfg.Cscli.PrometheusUrl == "" { + log.Warn("can't collect metrics: prometheus_uri is not set") + } + + humanMetrics := new(bytes.Buffer) + + ms := climetrics.NewMetricStore() + + if err := ms.Fetch(ctx, cfg.Cscli.PrometheusUrl, db); err != nil { + return err + } + + if err := ms.Format(humanMetrics, cfg.Cscli.Color, nil, "human", false); err != nil { + return fmt.Errorf("could not format prometheus metrics: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.Cscli.PrometheusUrl, nil) + if err != nil { + return fmt.Errorf("could not create request to prometheus endpoint: %w", err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not get metrics from prometheus endpoint: %w", err) + } + + defer resp.Body.Close() + + cli.writeToZip(zw, SUPPORT_METRICS_DIR+"metrics.prometheus", time.Now(), resp.Body) + + stripped := stripAnsiString(humanMetrics.String()) + + cli.writeToZip(zw, SUPPORT_METRICS_DIR+"metrics.human", time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpVersion(zw *zip.Writer) { + log.Info("Collecting version") + + cli.writeToZip(zw, SUPPORT_VERSION_PATH, time.Now(), strings.NewReader(cwversion.FullString())) +} + +func (cli *cliSupport) dumpFeatures(zw *zip.Writer) { + log.Info("Collecting feature flags") + + w := new(bytes.Buffer) + for _, k := range fflag.Crowdsec.GetEnabledFeatures() { + fmt.Fprintln(w, k) + } + + cli.writeToZip(zw, SUPPORT_FEATURES_PATH, time.Now(), w) +} + +func (cli *cliSupport) dumpOSInfo(zw *zip.Writer) error { + log.Info("Collecting OS info") + + info, err := osinfo.GetOSInfo() + if err != nil { + return err + } + + w := new(bytes.Buffer) + fmt.Fprintf(w, "Architecture: %s\n", info.Architecture) + fmt.Fprintf(w, "Family: %s\n", info.Family) + fmt.Fprintf(w, "ID: %s\n", info.ID) + fmt.Fprintf(w, "Name: %s\n", info.Name) + fmt.Fprintf(w, "Codename: %s\n", info.Codename) + fmt.Fprintf(w, "Version: %s\n", info.Version) + fmt.Fprintf(w, "Build: %s\n", info.Build) + + cli.writeToZip(zw, SUPPORT_OS_INFO_PATH, time.Now(), w) + + return nil +} + +func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { + log.Infof("Collecting hub") + + if hub == nil { + return errors.New("no hub connection") + } + + out := new(bytes.Buffer) + ch := clihub.New(cli.cfg) + + if err := ch.List(out, hub, false); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_HUB, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting bouncers") + + if db == nil { + return errors.New("no database connection") + } + + out := new(bytes.Buffer) + cb := clibouncer.New(cli.cfg) + + if err := cb.List(ctx, out, db); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_BOUNCERS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting agents") + + if db == nil { + return errors.New("no database connection") + } + + out := new(bytes.Buffer) + cm := climachine.New(cli.cfg) + + if err := cm.List(ctx, out, db); err != nil { + return err + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_AGENTS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpLAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { + log.Info("Collecting LAPI status") + + out := new(bytes.Buffer) + cl := clilapi.New(cli.cfg) + + err := cl.Status(ctx, out, hub) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_LAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpCAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { + log.Info("Collecting CAPI status") + + out := new(bytes.Buffer) + cc := clicapi.New(cli.cfg) + + err := cc.Status(ctx, out, hub) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_CAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpPAPIStatus(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting PAPI status") + + out := new(bytes.Buffer) + cp := clipapi.New(cli.cfg) + + err := cp.Status(ctx, out, db) + if err != nil { + fmt.Fprintf(out, "%s\n", err) + } + + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_PAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpConfigYAML(zw *zip.Writer) error { + log.Info("Collecting crowdsec config") + + cfg := cli.cfg() + + config, err := os.ReadFile(*cfg.FilePath) + if err != nil { + return fmt.Errorf("could not read config file: %w", err) + } + + r := regexp.MustCompile(`(\s+password:|\s+user:|\s+host:)\s+.*`) + + redacted := r.ReplaceAll(config, []byte("$1 ****REDACTED****")) + + cli.writeToZip(zw, SUPPORT_CROWDSEC_CONFIG_PATH, time.Now(), bytes.NewReader(redacted)) + + return nil +} + +func (cli *cliSupport) dumpPprof(ctx context.Context, zw *zip.Writer, prometheusCfg csconfig.PrometheusCfg, endpoint string) error { + log.Infof("Collecting pprof/%s data", endpoint) + + ctx, cancel := context.WithTimeout(ctx, 120*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + fmt.Sprintf( + "http://%s/debug/pprof/%s?debug=1", + net.JoinHostPort( + prometheusCfg.ListenAddr, + strconv.Itoa(prometheusCfg.ListenPort), + ), + endpoint, + ), + nil, + ) + if err != nil { + return fmt.Errorf("could not create request to pprof endpoint: %w", err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not get pprof data from endpoint: %w", err) + } + + defer resp.Body.Close() + + cli.writeToZip(zw, SUPPORT_PPROF_DIR+endpoint+".pprof", time.Now(), resp.Body) + + return nil +} + +func (cli *cliSupport) dumpProfiles(zw *zip.Writer) { + log.Info("Collecting crowdsec profile") + + cfg := cli.cfg() + cli.writeFileToZip(zw, SUPPORT_CROWDSEC_PROFILE_PATH, cfg.API.Server.ProfilesPath) +} + +func (cli *cliSupport) dumpAcquisitionConfig(zw *zip.Writer) { + log.Info("Collecting acquisition config") + + cfg := cli.cfg() + + for _, filename := range cfg.Crowdsec.AcquisitionFiles { + fname := strings.ReplaceAll(filename, string(filepath.Separator), "___") + cli.writeFileToZip(zw, SUPPORT_ACQUISITION_DIR+fname, filename) + } +} + +func (cli *cliSupport) dumpLogs(zw *zip.Writer) error { + log.Info("Collecting CrowdSec logs") + + cfg := cli.cfg() + + logDir := cfg.Common.LogDir + + logFiles, err := filepath.Glob(filepath.Join(logDir, "crowdsec*.log")) + if err != nil { + return fmt.Errorf("could not list log files: %w", err) + } + + for _, filename := range logFiles { + cli.writeFileToZip(zw, SUPPORT_LOG_DIR+filepath.Base(filename), filename) + } + + return nil +} + +func (cli *cliSupport) dumpCrash(zw *zip.Writer) error { + log.Info("Collecting crash dumps") + + traceFiles, err := trace.List() + if err != nil { + return fmt.Errorf("could not list crash dumps: %w", err) + } + + for _, filename := range traceFiles { + cli.writeFileToZip(zw, SUPPORT_CRASH_DIR+filepath.Base(filename), filename) + } + + return nil +} + +type configGetter func() *csconfig.Config + +type cliSupport struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSupport { + return &cliSupport{ + cfg: cfg, + } +} + +func (cli *cliSupport) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "support [action]", + Short: "Provide commands to help during support", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.NewDumpCmd()) + + return cmd +} + +// writeToZip adds a file to the zip archive, from a reader +func (cli *cliSupport) writeToZip(zipWriter *zip.Writer, filename string, mtime time.Time, reader io.Reader) { + header := &zip.FileHeader{ + Name: filename, + Method: zip.Deflate, + Modified: mtime, + } + + fw, err := zipWriter.CreateHeader(header) + if err != nil { + log.Errorf("could not add zip entry for %s: %s", filename, err) + return + } + + _, err = io.Copy(fw, reader) + if err != nil { + log.Errorf("could not write zip entry for %s: %s", filename, err) + } +} + +// writeFileToZip adds a file to the zip archive, from a file, and retains the mtime +func (cli *cliSupport) writeFileToZip(zw *zip.Writer, filename string, fromFile string) { + mtime := time.Now() + + fi, err := os.Stat(fromFile) + if err == nil { + mtime = fi.ModTime() + } + + fin, err := os.Open(fromFile) + if err != nil { + log.Errorf("could not open file %s: %s", fromFile, err) + return + } + defer fin.Close() + + cli.writeToZip(zw, filename, mtime, fin) +} + +func (cli *cliSupport) dump(ctx context.Context, outFile string) error { + var skipCAPI, skipLAPI, skipAgent bool + + collector := &StringHook{ + LogLevels: log.AllLevels, + } + log.AddHook(collector) + + cfg := cli.cfg() + + if outFile == "" { + outFile = filepath.Join(os.TempDir(), "crowdsec-support.zip") + } + + w := bytes.NewBuffer(nil) + zipWriter := zip.NewWriter(w) + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + log.Warn(err) + } + + if err = cfg.LoadAPIServer(true); err != nil { + log.Warnf("could not load LAPI, skipping CAPI check") + + skipCAPI = true + } + + if err = cfg.LoadCrowdsec(); err != nil { + log.Warnf("could not load agent config, skipping crowdsec config check") + + skipAgent = true + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + log.Warn("Could not init hub, running on LAPI ? Hub related information will not be collected") + // XXX: lapi status check requires scenarios, will return an error + } + + if cfg.API.Client == nil || cfg.API.Client.Credentials == nil { + log.Warn("no agent credentials found, skipping LAPI connectivity check") + + skipLAPI = true + } + + if cfg.API.Server == nil || cfg.API.Server.OnlineClient == nil || cfg.API.Server.OnlineClient.Credentials == nil { + log.Warn("no CAPI credentials found, skipping CAPI connectivity check") + + skipCAPI = true + } + + if err = cli.dumpMetrics(ctx, db, zipWriter); err != nil { + log.Warn(err) + } + + if err = cli.dumpOSInfo(zipWriter); err != nil { + log.Warnf("could not collect OS information: %s", err) + } + + if err = cli.dumpConfigYAML(zipWriter); err != nil { + log.Warnf("could not collect main config file: %s", err) + } + + if err = cli.dumpHubItems(zipWriter, hub); err != nil { + log.Warnf("could not collect hub information: %s", err) + } + + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect bouncers information: %s", err) + } + + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect agents information: %s", err) + } + + if !skipCAPI { + if err = cli.dumpCAPIStatus(ctx, zipWriter, hub); err != nil { + log.Warnf("could not collect CAPI status: %s", err) + } + + if err = cli.dumpPAPIStatus(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect PAPI status: %s", err) + } + } + + if !skipLAPI { + if err = cli.dumpLAPIStatus(ctx, zipWriter, hub); err != nil { + log.Warnf("could not collect LAPI status: %s", err) + } + + // call pprof separately, one might fail for timeout + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "goroutine"); err != nil { + log.Warnf("could not collect pprof goroutine data: %s", err) + } + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "heap"); err != nil { + log.Warnf("could not collect pprof heap data: %s", err) + } + + if err = cli.dumpPprof(ctx, zipWriter, *cfg.Prometheus, "profile"); err != nil { + log.Warnf("could not collect pprof cpu data: %s", err) + } + + cli.dumpProfiles(zipWriter) + } + + if !skipAgent { + cli.dumpAcquisitionConfig(zipWriter) + } + + if err = cli.dumpCrash(zipWriter); err != nil { + log.Warnf("could not collect crash dumps: %s", err) + } + + if err = cli.dumpLogs(zipWriter); err != nil { + log.Warnf("could not collect log files: %s", err) + } + + cli.dumpVersion(zipWriter) + cli.dumpFeatures(zipWriter) + + // log of the dump process, without color codes + collectedOutput := stripAnsiString(collector.LogBuilder.String()) + + cli.writeToZip(zipWriter, "dump.log", time.Now(), strings.NewReader(collectedOutput)) + + err = zipWriter.Close() + if err != nil { + return fmt.Errorf("could not finalize zip file: %w", err) + } + + if outFile == "-" { + _, err = os.Stdout.Write(w.Bytes()) + return err + } + + err = os.WriteFile(outFile, w.Bytes(), 0o600) + if err != nil { + return fmt.Errorf("could not write zip file to %s: %w", outFile, err) + } + + log.Infof("Written zip file to %s", outFile) + + return nil +} + +func (cli *cliSupport) NewDumpCmd() *cobra.Command { + var outFile string + + cmd := &cobra.Command{ + Use: "dump", + Short: "Dump all your configuration to a zip file for easier support", + Long: `Dump the following information: +- Crowdsec version +- OS version +- Enabled feature flags +- Latest Crowdsec logs (log processor, LAPI, remediation components) +- Installed collections, parsers, scenarios... +- Bouncers and machines list +- CAPI/LAPI status +- Crowdsec config (sensitive information like username and password are redacted) +- Crowdsec metrics +- Stack trace in case of process crash`, + Example: `cscli support dump +cscli support dump -f /tmp/crowdsec-support.zip +`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + output := cli.cfg().Cscli.Output + if output != "human" { + return fmt.Errorf("output format %s not supported for this command", output) + } + return cli.dump(cmd.Context(), outFile) + }, + } + + cmd.Flags().StringVarP(&outFile, "outFile", "f", "", "File to dump the information to") + + return cmd +} diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index e60246db790..e88845798e2 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -4,19 +4,29 @@ import ( "github.com/spf13/cobra" ) -func NewConfigCmd() *cobra.Command { - cmdConfig := &cobra.Command{ +type cliConfig struct { + cfg configGetter +} + +func NewCLIConfig(cfg configGetter) *cliConfig { + return &cliConfig{ + cfg: cfg, + } +} + +func (cli *cliConfig) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "config [command]", Short: "Allows to view current config", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, } - cmdConfig.AddCommand(NewConfigShowCmd()) - cmdConfig.AddCommand(NewConfigShowYAMLCmd()) - cmdConfig.AddCommand(NewConfigBackupCmd()) - cmdConfig.AddCommand(NewConfigRestoreCmd()) - cmdConfig.AddCommand(NewConfigFeatureFlagsCmd()) + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newShowYAMLCmd()) + cmd.AddCommand(cli.newBackupCmd()) + cmd.AddCommand(cli.newRestoreCmd()) + cmd.AddCommand(cli.newFeatureFlagsCmd()) - return cmdConfig + return cmd } diff --git a/cmd/crowdsec-cli/config_backup.go b/cmd/crowdsec-cli/config_backup.go index 9414fa51033..d23aff80a78 100644 --- a/cmd/crowdsec-cli/config_backup.go +++ b/cmd/crowdsec-cli/config_backup.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -13,16 +14,14 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func backupHub(dirPath string) error { - hub, err := require.Hub(csConfig, nil, nil) +func (cli *cliConfig) backupHub(dirPath string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) if err != nil { return err } for _, itemType := range cwhub.ItemTypes { - clog := log.WithFields(log.Fields{ - "type": itemType, - }) + clog := log.WithField("type", itemType) itemMap := hub.GetItemMap(itemType) if itemMap == nil { @@ -32,27 +31,25 @@ func backupHub(dirPath string) error { itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itemType) if err = os.MkdirAll(itemDirectory, os.ModePerm); err != nil { - return fmt.Errorf("error while creating %s : %s", itemDirectory, err) + return fmt.Errorf("error while creating %s: %w", itemDirectory, err) } upstreamParsers := []string{} for k, v := range itemMap { - clog = clog.WithFields(log.Fields{ - "file": v.Name, - }) - if !v.State.Installed { //only backup installed ones - clog.Debugf("[%s] : not installed", k) + clog = clog.WithField("file", v.Name) + if !v.State.Installed { // only backup installed ones + clog.Debugf("[%s]: not installed", k) continue } - //for the local/tainted ones, we back up the full file + // for the local/tainted ones, we back up the full file if v.State.Tainted || v.State.IsLocal() || !v.State.UpToDate { - //we need to backup stages for parsers + // we need to backup stages for parsers if itemType == cwhub.PARSERS || itemType == cwhub.POSTOVERFLOWS { fstagedir := fmt.Sprintf("%s%s", itemDirectory, v.Stage) if err = os.MkdirAll(fstagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage dir %s : %s", fstagedir, err) + return fmt.Errorf("error while creating stage dir %s: %w", fstagedir, err) } } @@ -60,7 +57,7 @@ func backupHub(dirPath string) error { tfile := fmt.Sprintf("%s%s/%s", itemDirectory, v.Stage, v.FileName) if err = CopyFile(v.State.LocalPath, tfile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itemType, v.State.LocalPath, tfile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itemType, v.State.LocalPath, tfile, err) } clog.Infof("local/tainted saved %s to %s", v.State.LocalPath, tfile) @@ -68,21 +65,21 @@ func backupHub(dirPath string) error { continue } - clog.Debugf("[%s] : from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) + clog.Debugf("[%s]: from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) clog.Infof("saving, version:%s, up-to-date:%t", v.Version, v.State.UpToDate) upstreamParsers = append(upstreamParsers, v.Name) } - //write the upstream items + // write the upstream items upstreamParsersFname := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itemType) upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") if err != nil { - return fmt.Errorf("failed marshaling upstream parsers : %s", err) + return fmt.Errorf("failed to serialize upstream parsers: %w", err) } err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0o644) if err != nil { - return fmt.Errorf("unable to write to %s %s : %s", itemType, upstreamParsersFname, err) + return fmt.Errorf("unable to write to %s %s: %w", itemType, upstreamParsersFname, err) } clog.Infof("Wrote %d entries for %s to %s", len(upstreamParsers), itemType, upstreamParsersFname) @@ -102,11 +99,13 @@ func backupHub(dirPath string) error { - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func backupConfigToDirectory(dirPath string) error { +func (cli *cliConfig) backup(dirPath string) error { var err error + cfg := cli.cfg() + if dirPath == "" { - return fmt.Errorf("directory path can't be empty") + return errors.New("directory path can't be empty") } log.Infof("Starting configuration backup") @@ -121,10 +120,10 @@ func backupConfigToDirectory(dirPath string) error { return fmt.Errorf("while creating %s: %w", dirPath, err) } - if csConfig.ConfigPaths.SimulationFilePath != "" { + if cfg.ConfigPaths.SimulationFilePath != "" { backupSimulation := filepath.Join(dirPath, "simulation.yaml") - if err = CopyFile(csConfig.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", csConfig.ConfigPaths.SimulationFilePath, backupSimulation, err) + if err = CopyFile(cfg.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.ConfigPaths.SimulationFilePath, backupSimulation, err) } log.Infof("Saved simulation to %s", backupSimulation) @@ -134,22 +133,22 @@ func backupConfigToDirectory(dirPath string) error { - backup AcquisitionFilePath - backup the other files of acquisition directory */ - if csConfig.Crowdsec != nil && csConfig.Crowdsec.AcquisitionFilePath != "" { + if cfg.Crowdsec != nil && cfg.Crowdsec.AcquisitionFilePath != "" { backupAcquisition := filepath.Join(dirPath, "acquis.yaml") - if err = CopyFile(csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition, err) + if err = CopyFile(cfg.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.Crowdsec.AcquisitionFilePath, backupAcquisition, err) } } acquisBackupDir := filepath.Join(dirPath, "acquis") if err = os.Mkdir(acquisBackupDir, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %s", acquisBackupDir, err) + return fmt.Errorf("error while creating %s: %w", acquisBackupDir, err) } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { /*if it was the default one, it was already backup'ed*/ - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { continue } @@ -169,56 +168,48 @@ func backupConfigToDirectory(dirPath string) error { if ConfigFilePath != "" { backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if err = CopyFile(ConfigFilePath, backupMain); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", ConfigFilePath, backupMain, err) + return fmt.Errorf("failed copy %s to %s: %w", ConfigFilePath, backupMain, err) } log.Infof("Saved default yaml to %s", backupMain) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.OnlineClient != nil && cfg.API.Server.OnlineClient.CredentialsFilePath != "" { backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) + if err = CopyFile(cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) } log.Infof("Saved online API credentials to %s", backupCAPICreds) } - if csConfig.API != nil && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Client != nil && cfg.API.Client.CredentialsFilePath != "" { backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Client.CredentialsFilePath, backupLAPICreds, err) + if err = CopyFile(cfg.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Client.CredentialsFilePath, backupLAPICreds, err) } log.Infof("Saved local API credentials to %s", backupLAPICreds) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.ProfilesPath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.ProfilesPath != "" { backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.ProfilesPath, backupProfiles); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.ProfilesPath, backupProfiles, err) + if err = CopyFile(cfg.API.Server.ProfilesPath, backupProfiles); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.ProfilesPath, backupProfiles, err) } log.Infof("Saved profiles to %s", backupProfiles) } - if err = backupHub(dirPath); err != nil { - return fmt.Errorf("failed to backup hub config: %s", err) - } - - return nil -} - -func runConfigBackup(cmd *cobra.Command, args []string) error { - if err := backupConfigToDirectory(args[0]); err != nil { - return fmt.Errorf("failed to backup config: %w", err) + if err = cli.backupHub(dirPath); err != nil { + return fmt.Errorf("failed to backup hub config: %w", err) } return nil } -func NewConfigBackupCmd() *cobra.Command { - cmdConfigBackup := &cobra.Command{ +func (cli *cliConfig) newBackupCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `backup "directory"`, Short: "Backup current config", Long: `Backup the current crowdsec configuration including : @@ -232,8 +223,14 @@ func NewConfigBackupCmd() *cobra.Command { Example: `cscli config backup ./my-backup`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigBackup, + RunE: func(_ *cobra.Command, args []string) error { + if err := cli.backup(args[0]); err != nil { + return fmt.Errorf("failed to backup config: %w", err) + } + + return nil + }, } - return cmdConfigBackup + return cmd } diff --git a/cmd/crowdsec-cli/config_feature_flags.go b/cmd/crowdsec-cli/config_feature_flags.go index fbba1f56736..d1dbe2b93b7 100644 --- a/cmd/crowdsec-cli/config_feature_flags.go +++ b/cmd/crowdsec-cli/config_feature_flags.go @@ -11,14 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - showRetired, err := flags.GetBool("retired") - if err != nil { - return err - } - +func (cli *cliConfig) featureFlags(showRetired bool) error { green := color.New(color.FgGreen).SprintFunc() red := color.New(color.FgRed).SprintFunc() yellow := color.New(color.FgYellow).SprintFunc() @@ -121,18 +114,22 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { return nil } -func NewConfigFeatureFlagsCmd() *cobra.Command { - cmdConfigFeatureFlags := &cobra.Command{ +func (cli *cliConfig) newFeatureFlagsCmd() *cobra.Command { + var showRetired bool + + cmd := &cobra.Command{ Use: "feature-flags", Short: "Displays feature flag status", Long: `Displays the supported feature flags and their current status.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigFeatureFlags, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.featureFlags(showRetired) + }, } - flags := cmdConfigFeatureFlags.Flags() - flags.Bool("retired", false, "Show retired features") + flags := cmd.Flags() + flags.BoolVar(&showRetired, "retired", false, "Show retired features") - return cmdConfigFeatureFlags + return cmd } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index e9c2fa9aa23..c32328485ec 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -1,28 +1,23 @@ package main import ( + "context" "encoding/json" "fmt" - "io" "os" "path/filepath" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type OldAPICfg struct { - MachineID string `json:"machine_id"` - Password string `json:"password"` -} +func (cli *cliConfig) restoreHub(ctx context.Context, dirPath string) error { + cfg := cli.cfg() -func restoreHub(dirPath string) error { - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), nil) + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), nil) if err != nil { return err } @@ -38,14 +33,14 @@ func restoreHub(dirPath string) error { file, err := os.ReadFile(upstreamListFN) if err != nil { - return fmt.Errorf("error while opening %s : %s", upstreamListFN, err) + return fmt.Errorf("error while opening %s: %w", upstreamListFN, err) } var upstreamList []string err = json.Unmarshal(file, &upstreamList) if err != nil { - return fmt.Errorf("error unmarshaling %s : %s", upstreamListFN, err) + return fmt.Errorf("error parsing %s: %w", upstreamListFN, err) } for _, toinstall := range upstreamList { @@ -55,8 +50,7 @@ func restoreHub(dirPath string) error { continue } - err := item.Install(false, false) - if err != nil { + if err = item.Install(ctx, false, false); err != nil { log.Errorf("Error while installing %s : %s", toinstall, err) } } @@ -64,42 +58,43 @@ func restoreHub(dirPath string) error { /*restore the local and tainted items*/ files, err := os.ReadDir(itemDirectory) if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory, err) } for _, file := range files { - //this was the upstream data + // this was the upstream data if file.Name() == fmt.Sprintf("upstream-%s.json", itype) { continue } if itype == cwhub.PARSERS || itype == cwhub.POSTOVERFLOWS { - //we expect a stage here + // we expect a stage here if !file.IsDir() { continue } stage := file.Name() - stagedir := fmt.Sprintf("%s/%s/%s/", csConfig.ConfigPaths.ConfigDir, itype, stage) + stagedir := fmt.Sprintf("%s/%s/%s/", cfg.ConfigPaths.ConfigDir, itype, stage) log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage directory %s : %s", stagedir, err) + return fmt.Errorf("error while creating stage directory %s: %w", stagedir, err) } // find items ifiles, err := os.ReadDir(itemDirectory + "/" + stage + "/") if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory+"/"+stage, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory+"/"+stage, err) } - //finally copy item + + // finally copy item for _, tfile := range ifiles { log.Infof("Going to restore local/tainted [%s]", tfile.Name()) sourceFile := fmt.Sprintf("%s/%s/%s", itemDirectory, stage, tfile.Name()) destinationFile := fmt.Sprintf("%s%s", stagedir, tfile.Name()) if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } log.Infof("restored %s to %s", sourceFile, destinationFile) @@ -107,10 +102,12 @@ func restoreHub(dirPath string) error { } else { log.Infof("Going to restore local/tainted [%s]", file.Name()) sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) - destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) + destinationFile := fmt.Sprintf("%s/%s/%s", cfg.ConfigPaths.ConfigDir, itype, file.Name()) + if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } + log.Infof("restored %s to %s", sourceFile, destinationFile) } } @@ -130,90 +127,64 @@ func restoreHub(dirPath string) error { - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { +func (cli *cliConfig) restore(ctx context.Context, dirPath string) error { var err error - if !oldBackup { - backupMain := fmt.Sprintf("%s/config.yaml", dirPath) - if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupMain, csConfig.ConfigPaths.ConfigDir, err) - } + cfg := cli.cfg() + + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) + if _, err = os.Stat(backupMain); err == nil { + if cfg.ConfigPaths != nil && cfg.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir)); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, cfg.ConfigPaths.ConfigDir, err) } } + } - // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) + // Now we have config.yaml, we should regenerate config struct to have rights paths etc + ConfigFilePath = fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir) - initConfig() + log.Debug("Reloading configuration") - backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) - } - } + csConfig, _, err = loadConfigFor("config") + if err != nil { + return fmt.Errorf("failed to reload configuration: %w", err) + } - backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) - } - } + cfg = cli.cfg() - backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupProfiles, csConfig.API.Server.ProfilesPath, err) - } + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupCAPICreds); err == nil { + if err = CopyFile(backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath, err) } - } else { - var oldAPICfg OldAPICfg - backupOldAPICfg := fmt.Sprintf("%s/api_creds.json", dirPath) + } - jsonFile, err := os.Open(backupOldAPICfg) - if err != nil { - log.Warningf("failed to open %s : %s", backupOldAPICfg, err) - } else { - byteValue, _ := io.ReadAll(jsonFile) - err = json.Unmarshal(byteValue, &oldAPICfg) - if err != nil { - return fmt.Errorf("failed to load json file %s : %s", backupOldAPICfg, err) - } + backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupLAPICreds); err == nil { + if err = CopyFile(backupLAPICreds, cfg.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, cfg.API.Client.CredentialsFilePath, err) + } + } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: oldAPICfg.MachineID, - Password: oldAPICfg.Password, - URL: CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to dump api credentials: %s", err) - } - apiConfigDumpFile := fmt.Sprintf("%s/online_api_credentials.yaml", csConfig.ConfigPaths.ConfigDir) - if csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - apiConfigDumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } - err = os.WriteFile(apiConfigDumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", apiConfigDumpFile, err) - } - log.Infof("Saved API credentials to %s", apiConfigDumpFile) + backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) + if _, err = os.Stat(backupProfiles); err == nil { + if err = CopyFile(backupProfiles, cfg.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, cfg.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { - if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + if err = CopyFile(backupSimulation, cfg.ConfigPaths.SimulationFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, cfg.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ - if csConfig.Crowdsec.AcquisitionDirPath != "" { - if err = os.MkdirAll(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s : %s", csConfig.Crowdsec.AcquisitionDirPath, err) + if cfg.Crowdsec.AcquisitionDirPath != "" { + if err = os.MkdirAll(cfg.Crowdsec.AcquisitionDirPath, 0o700); err != nil { + return fmt.Errorf("error while creating %s: %w", cfg.Crowdsec.AcquisitionDirPath, err) } } @@ -222,16 +193,16 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { if _, err = os.Stat(backupAcquisition); err == nil { log.Debugf("restoring backup'ed %s", backupAcquisition) - if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + if err = CopyFile(backupAcquisition, cfg.Crowdsec.AcquisitionFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, cfg.Crowdsec.AcquisitionFilePath, err) } } - // if there is files in the acquis backup dir, restore them + // if there are files in the acquis backup dir, restore them acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml") if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil { for _, acquisFile := range acquisFiles { - targetFname, err := filepath.Abs(csConfig.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) + targetFname, err := filepath.Abs(cfg.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) if err != nil { return fmt.Errorf("while saving %s to %s: %w", acquisFile, targetFname, err) } @@ -239,17 +210,17 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring %s to %s", acquisFile, targetFname) if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } } } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { log.Infof("backup filepath from dir -> %s", acquisFile) // if it was the default one, it has already been backed up - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { log.Infof("skip this one") continue } @@ -260,37 +231,22 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { } if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } log.Infof("Saved acquis %s to %s", acquisFile, targetFname) } } - if err = restoreHub(dirPath); err != nil { - return fmt.Errorf("failed to restore hub config : %s", err) - } - - return nil -} - -func runConfigRestore(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - oldBackup, err := flags.GetBool("old-backup") - if err != nil { - return err - } - - if err := restoreConfigFromDirectory(args[0], oldBackup); err != nil { - return fmt.Errorf("failed to restore config from %s: %w", args[0], err) + if err = cli.restoreHub(ctx, dirPath); err != nil { + return fmt.Errorf("failed to restore hub config: %w", err) } return nil } -func NewConfigRestoreCmd() *cobra.Command { - cmdConfigRestore := &cobra.Command{ +func (cli *cliConfig) newRestoreCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `restore "directory"`, Short: `Restore config in backup "directory"`, Long: `Restore the crowdsec configuration from specified backup "directory" including: @@ -303,11 +259,16 @@ func NewConfigRestoreCmd() *cobra.Command { - Backup of API credentials (local API and online API)`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigRestore, - } + RunE: func(cmd *cobra.Command, args []string) error { + dirPath := args[0] - flags := cmdConfigRestore.Flags() - flags.BoolP("old-backup", "", false, "To use when you are upgrading crowdsec v0.X to v1.X and you need to restore backup from v0.X") + if err := cli.restore(cmd.Context(), dirPath); err != nil { + return fmt.Errorf("failed to restore config from %s: %w", dirPath, err) + } + + return nil + }, + } - return cmdConfigRestore + return cmd } diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index bab911cc340..2d3ac488ba2 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -6,17 +6,19 @@ import ( "os" "text/template" - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" "github.com/sanity-io/litter" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) -func showConfigKey(key string) error { +func (cli *cliConfig) showKey(key string) error { + cfg := cli.cfg() + type Env struct { Config *csconfig.Config } @@ -30,15 +32,15 @@ func showConfigKey(key string) error { return err } - output, err := expr.Run(program, Env{Config: csConfig}) + output, err := expr.Run(program, Env{Config: cfg}) if err != nil { return err } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human", "raw": // Don't use litter for strings, it adds quotes - // that we didn't have before + // that would break compatibility with previous versions switch output.(type) { case string: fmt.Println(output) @@ -48,16 +50,17 @@ func showConfigKey(key string) error { case "json": data, err := json.MarshalIndent(output, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) } return nil } -var configShowTemplate = `Global: +func (cli *cliConfig) template() string { + return `Global: {{- if .ConfigPaths }} - Configuration Folder : {{.ConfigPaths.ConfigDir}} @@ -100,6 +103,7 @@ API Client: {{- if .API.Server }} Local API Server{{if and .API.Server.Enable (not (ValueBool .API.Server.Enable))}} (disabled){{end}}: - Listen URL : {{.API.Server.ListenURI}} + - Listen Socket : {{.API.Server.ListenSocket}} - Profile File : {{.API.Server.ProfilesPath}} {{- if .API.Server.TLS }} @@ -181,74 +185,74 @@ Central API: {{- end }} {{- end }} ` +} -func runConfigShow(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - if err := csConfig.LoadAPIClient(); err != nil { - log.Errorf("failed to load API client configuration: %s", err) - // don't return, we can still show the configuration - } - - key, err := flags.GetString("key") - if err != nil { - return err - } - - if key != "" { - return showConfigKey(key) - } +func (cli *cliConfig) show() error { + cfg := cli.cfg() - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": // The tests on .Enable look funny because the option has a true default which has // not been set yet (we don't really load the LAPI) and go templates don't dereference // pointers in boolean tests. Prefix notation is the cherry on top. funcs := template.FuncMap{ // can't use generics here - "ValueBool": func(b *bool) bool { return b!=nil && *b }, + "ValueBool": func(b *bool) bool { return b != nil && *b }, } - tmp, err := template.New("config").Funcs(funcs).Parse(configShowTemplate) + tmp, err := template.New("config").Funcs(funcs).Parse(cli.template()) if err != nil { return err } - err = tmp.Execute(os.Stdout, csConfig) + err = tmp.Execute(os.Stdout, cfg) if err != nil { return err } case "json": - data, err := json.MarshalIndent(csConfig, "", " ") + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) case "raw": - data, err := yaml.Marshal(csConfig) + data, err := yaml.Marshal(cfg) if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } - fmt.Printf("%s\n", string(data)) + fmt.Println(string(data)) } return nil } -func NewConfigShowCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ Use: "show", Short: "Displays current config", Long: `Displays the current cli configuration.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShow, + RunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { + log.Errorf("failed to load API client configuration: %s", err) + // don't return, we can still show the configuration + } + + if key != "" { + return cli.showKey(key) + } + + return cli.show() + }, } - flags := cmdConfigShow.Flags() - flags.StringP("key", "", "", "Display only this value (Config.API.Server.ListenURI)") + flags := cmd.Flags() + flags.StringVarP(&key, "key", "", "", "Display only this value (Config.API.Server.ListenURI)") - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/config_showyaml.go b/cmd/crowdsec-cli/config_showyaml.go index 82bc67ffcb8..52daee6a65e 100644 --- a/cmd/crowdsec-cli/config_showyaml.go +++ b/cmd/crowdsec-cli/config_showyaml.go @@ -6,19 +6,21 @@ import ( "github.com/spf13/cobra" ) -func runConfigShowYAML(cmd *cobra.Command, args []string) error { +func (cli *cliConfig) showYAML() error { fmt.Println(mergedConfig) return nil } -func NewConfigShowYAMLCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowYAMLCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "show-yaml", Short: "Displays merged config.yaml + config.yaml.local", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShowYAML, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.showYAML() + }, } - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/console.go b/cmd/crowdsec-cli/console.go deleted file mode 100644 index dcd6fb37f62..00000000000 --- a/cmd/crowdsec-cli/console.go +++ /dev/null @@ -1,390 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "net/url" - "os" - "strings" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/go-cs-lib/version" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func NewConsoleCmd() *cobra.Command { - var cmdConsole = &cobra.Command{ - Use: "console [action]", - Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { - return err - } - if err := require.CAPI(csConfig); err != nil { - return err - } - if err := require.CAPIRegistered(csConfig); err != nil { - return err - } - return nil - }, - } - - name := "" - overwrite := false - tags := []string{} - opts := []string{} - - cmdEnroll := &cobra.Command{ - Use: "enroll [enroll-key]", - Short: "Enroll this instance to https://app.crowdsec.net [requires local API]", - Long: ` -Enroll this instance to https://app.crowdsec.net - -You can get your enrollment key by creating an account on https://app.crowdsec.net. -After running this command your will need to validate the enrollment in the webapp.`, - Example: fmt.Sprintf(`cscli console enroll YOUR-ENROLL-KEY - cscli console enroll --name [instance_name] YOUR-ENROLL-KEY - cscli console enroll --name [instance_name] --tags [tag_1] --tags [tag_2] YOUR-ENROLL-KEY - cscli console enroll --enable context,manual YOUR-ENROLL-KEY - - valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - apiURL, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - return fmt.Errorf("could not parse CAPI URL: %s", err) - } - - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - return err - } - - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get installed scenarios: %s", err) - } - - if len(scenarios) == 0 { - scenarios = make([]string, 0) - } - - enable_opts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} - if len(opts) != 0 { - for _, opt := range opts { - valid := false - if opt == "all" { - enable_opts = csconfig.CONSOLE_CONFIGS - break - } - for _, available_opt := range csconfig.CONSOLE_CONFIGS { - if opt == available_opt { - valid = true - enable := true - for _, enabled_opt := range enable_opts { - if opt == enabled_opt { - enable = false - continue - } - } - if enable { - enable_opts = append(enable_opts, opt) - } - break - } - } - if !valid { - return fmt.Errorf("option %s doesn't exist", opt) - - } - } - } - - c, _ := apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Server.OnlineClient.Credentials.Login, - Password: password, - Scenarios: scenarios, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - VersionPrefix: "v3", - }) - resp, err := c.Auth.EnrollWatcher(context.Background(), args[0], name, tags, overwrite) - if err != nil { - return fmt.Errorf("could not enroll instance: %s", err) - } - if resp.Response.StatusCode == 200 && !overwrite { - log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll") - return nil - } - - if err := SetConsoleOpts(enable_opts, true); err != nil { - return err - } - - for _, opt := range enable_opts { - log.Infof("Enabled %s : %s", opt, csconfig.CONSOLE_CONFIGS_HELP[opt]) - } - log.Info("Watcher successfully enrolled. Visit https://app.crowdsec.net to accept it.") - log.Info("Please restart crowdsec after accepting the enrollment.") - return nil - }, - } - cmdEnroll.Flags().StringVarP(&name, "name", "n", "", "Name to display in the console") - cmdEnroll.Flags().BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance") - cmdEnroll.Flags().StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console") - cmdEnroll.Flags().StringSliceVarP(&opts, "enable", "e", opts, "Enable console options") - cmdConsole.AddCommand(cmdEnroll) - - var enableAll, disableAll bool - - cmdEnable := &cobra.Command{ - Use: "enable [option]", - Short: "Enable a console option", - Example: "sudo cscli console enable tainted", - Long: ` -Enable given information push to the central API. Allows to empower the console`, - ValidArgs: csconfig.CONSOLE_CONFIGS, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if enableAll { - if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil { - return err - } - log.Infof("All features have been enabled successfully") - } else { - if len(args) == 0 { - return fmt.Errorf("you must specify at least one feature to enable") - } - if err := SetConsoleOpts(args, true); err != nil { - return err - } - log.Infof("%v have been enabled", args) - } - log.Infof(ReloadMessage()) - return nil - }, - } - cmdEnable.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options") - cmdConsole.AddCommand(cmdEnable) - - cmdDisable := &cobra.Command{ - Use: "disable [option]", - Short: "Disable a console option", - Example: "sudo cscli console disable tainted", - Long: ` -Disable given information push to the central API.`, - ValidArgs: csconfig.CONSOLE_CONFIGS, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if disableAll { - if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil { - return err - } - log.Infof("All features have been disabled") - } else { - if err := SetConsoleOpts(args, false); err != nil { - return err - } - log.Infof("%v have been disabled", args) - } - - log.Infof(ReloadMessage()) - return nil - }, - } - cmdDisable.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options") - cmdConsole.AddCommand(cmdDisable) - - cmdConsoleStatus := &cobra.Command{ - Use: "status", - Short: "Shows status of the console options", - Example: `sudo cscli console status`, - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - switch csConfig.Cscli.Output { - case "human": - cmdConsoleStatusTable(color.Output, *csConfig) - case "json": - c := csConfig.API.Server.ConsoleConfig - out := map[string](*bool){ - csconfig.SEND_MANUAL_SCENARIOS: c.ShareManualDecisions, - csconfig.SEND_CUSTOM_SCENARIOS: c.ShareCustomScenarios, - csconfig.SEND_TAINTED_SCENARIOS: c.ShareTaintedScenarios, - csconfig.SEND_CONTEXT: c.ShareContext, - csconfig.CONSOLE_MANAGEMENT: c.ConsoleManagement, - } - data, err := json.MarshalIndent(out, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal configuration: %s", err) - } - fmt.Println(string(data)) - case "raw": - csvwriter := csv.NewWriter(os.Stdout) - err := csvwriter.Write([]string{"option", "enabled"}) - if err != nil { - return err - } - - rows := [][]string{ - {csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareManualDecisions)}, - {csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios)}, - {csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios)}, - {csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareContext)}, - {csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ConsoleManagement)}, - } - for _, row := range rows { - err = csvwriter.Write(row) - if err != nil { - return err - } - } - csvwriter.Flush() - } - return nil - }, - } - cmdConsole.AddCommand(cmdConsoleStatus) - - return cmdConsole -} - -func dumpConsoleConfig(c *csconfig.LocalApiServerCfg) error { - out, err := yaml.Marshal(c.ConsoleConfig) - if err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleConfigPath, err) - } - - if c.ConsoleConfigPath == "" { - c.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath - log.Debugf("Empty console_path, defaulting to %s", c.ConsoleConfigPath) - } - - if err := os.WriteFile(c.ConsoleConfigPath, out, 0o600); err != nil { - return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleConfigPath, err) - } - - return nil -} - -func SetConsoleOpts(args []string, wanted bool) error { - for _, arg := range args { - switch arg { - case csconfig.CONSOLE_MANAGEMENT: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ConsoleManagement != nil { - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement == wanted { - log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - } else { - log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - *csConfig.API.Server.ConsoleConfig.ConsoleManagement = wanted - } - } else { - log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - csConfig.API.Server.ConsoleConfig.ConsoleManagement = ptr.Of(wanted) - } - - if csConfig.API.Server.OnlineClient.Credentials != nil { - changed := false - if wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL == "" { - changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL - } else if !wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL != "" { - changed = true - csConfig.API.Server.OnlineClient.Credentials.PapiURL = "" - } - - if changed { - fileContent, err := yaml.Marshal(csConfig.API.Server.OnlineClient.Credentials) - if err != nil { - return fmt.Errorf("cannot marshal credentials: %s", err) - } - - log.Infof("Updating credentials file: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - - err = os.WriteFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600) - if err != nil { - return fmt.Errorf("cannot write credentials file: %s", err) - } - } - } - case csconfig.SEND_CUSTOM_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareCustomScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = ptr.Of(wanted) - } - case csconfig.SEND_TAINTED_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios != nil { - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = ptr.Of(wanted) - } - case csconfig.SEND_MANUAL_SCENARIOS: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareManualDecisions != nil { - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - *csConfig.API.Server.ConsoleConfig.ShareManualDecisions = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - csConfig.API.Server.ConsoleConfig.ShareManualDecisions = ptr.Of(wanted) - } - case csconfig.SEND_CONTEXT: - /*for each flag check if it's already set before setting it*/ - if csConfig.API.Server.ConsoleConfig.ShareContext != nil { - if *csConfig.API.Server.ConsoleConfig.ShareContext == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - *csConfig.API.Server.ConsoleConfig.ShareContext = wanted - } - } else { - log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - csConfig.API.Server.ConsoleConfig.ShareContext = ptr.Of(wanted) - } - default: - return fmt.Errorf("unknown flag %s", arg) - } - } - - if err := dumpConsoleConfig(csConfig.API.Server); err != nil { - return fmt.Errorf("failed writing console config: %s", err) - } - - return nil -} diff --git a/cmd/crowdsec-cli/console_table.go b/cmd/crowdsec-cli/console_table.go deleted file mode 100644 index 2a221e36f07..00000000000 --- a/cmd/crowdsec-cli/console_table.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "io" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/csconfig" -) - -func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) { - t := newTable(out) - t.SetRowLines(false) - - t.SetHeaders("Option Name", "Activated", "Description") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, option := range csconfig.CONSOLE_CONFIGS { - activated := string(emoji.CrossMark) - switch option { - case csconfig.SEND_CUSTOM_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios { - activated = string(emoji.CheckMarkButton) - } - case csconfig.SEND_MANUAL_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions { - activated = string(emoji.CheckMarkButton) - } - case csconfig.SEND_TAINTED_SCENARIOS: - if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios { - activated = string(emoji.CheckMarkButton) - } - case csconfig.SEND_CONTEXT: - if *csConfig.API.Server.ConsoleConfig.ShareContext { - activated = string(emoji.CheckMarkButton) - } - case csconfig.CONSOLE_MANAGEMENT: - if *csConfig.API.Server.ConsoleConfig.ConsoleManagement { - activated = string(emoji.CheckMarkButton) - } - } - t.AddRow(option, activated, csconfig.CONSOLE_CONFIGS_HELP[option]) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/copyfile.go b/cmd/crowdsec-cli/copyfile.go index 332f744be80..272fb3f7851 100644 --- a/cmd/crowdsec-cli/copyfile.go +++ b/cmd/crowdsec-cli/copyfile.go @@ -9,7 +9,6 @@ import ( log "github.com/sirupsen/logrus" ) - /*help to copy the file, ioutil doesn't offer the feature*/ func copyFileContents(src, dst string) (err error) { @@ -69,6 +68,7 @@ func CopyFile(sourceSymLink, destinationFile string) error { if !(destinationFileStat.Mode().IsRegular()) { return fmt.Errorf("copyFile: non-regular destination file %s (%q)", destinationFileStat.Name(), destinationFileStat.Mode().String()) } + if os.SameFile(sourceFileStat, destinationFileStat) { return err } @@ -80,4 +80,3 @@ func CopyFile(sourceSymLink, destinationFile string) error { return err } - diff --git a/cmd/crowdsec-cli/cstable/cstable.go b/cmd/crowdsec-cli/cstable/cstable.go new file mode 100644 index 00000000000..85ba491f4e8 --- /dev/null +++ b/cmd/crowdsec-cli/cstable/cstable.go @@ -0,0 +1,161 @@ +package cstable + +// transisional file to keep (minimal) backwards compatibility with the old table +// we can migrate the code to the new dependency later, it can already use the Writer interface + +import ( + "fmt" + "io" + "os" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/jedib0t/go-pretty/v6/text" + isatty "github.com/mattn/go-isatty" +) + +func shouldWeColorize(wantColor string) bool { + switch wantColor { + case "yes": + return true + case "no": + return false + default: + return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) + } +} + +type Table struct { + Writer table.Writer + output io.Writer + align []text.Align + alignHeader []text.Align +} + +func New(out io.Writer, wantColor string) *Table { + if out == nil { + panic("newTable: out is nil") + } + + t := table.NewWriter() + + // colorize output, use unicode box characters + fancy := shouldWeColorize(wantColor) + + colorOptions := table.ColorOptions{} + + if fancy { + colorOptions.Header = text.Colors{text.Italic} + colorOptions.Border = text.Colors{text.FgHiBlack} + colorOptions.Separator = text.Colors{text.FgHiBlack} + } + + // no upper/lower case transformations + format := table.FormatOptions{} + + box := table.StyleBoxDefault + if fancy { + box = table.StyleBoxRounded + } + + style := table.Style{ + Box: box, + Color: colorOptions, + Format: format, + HTML: table.DefaultHTMLOptions, + Options: table.OptionsDefault, + Title: table.TitleOptionsDefault, + } + + t.SetStyle(style) + + return &Table{ + Writer: t, + output: out, + align: make([]text.Align, 0), + alignHeader: make([]text.Align, 0), + } +} + +func NewLight(output io.Writer, wantColor string) *Table { + t := New(output, wantColor) + s := t.Writer.Style() + s.Box.Left = "" + s.Box.LeftSeparator = "" + s.Box.TopLeft = "" + s.Box.BottomLeft = "" + s.Box.Right = "" + s.Box.RightSeparator = "" + s.Box.TopRight = "" + s.Box.BottomRight = "" + s.Options.SeparateRows = false + s.Options.SeparateFooter = false + s.Options.SeparateHeader = true + s.Options.SeparateColumns = false + + return t +} + +// +// wrapper methods for backwards compatibility +// + +// setColumnConfigs must be called right before rendering, +// to allow for setting the alignment like the old API +func (t *Table) setColumnConfigs() { + configs := []table.ColumnConfig{} + // the go-pretty table does not expose the names or number of columns + for i := range len(t.align) { + configs = append(configs, table.ColumnConfig{ + Number: i + 1, + AlignHeader: t.alignHeader[i], + Align: t.align[i], + WidthMax: 60, + WidthMaxEnforcer: text.WrapSoft, + }) + } + + t.Writer.SetColumnConfigs(configs) +} + +func (t *Table) Render() { + // change default options for backwards compatibility. + // we do this late to allow changing the alignment like the old API + t.setColumnConfigs() + fmt.Fprintln(t.output, t.Writer.Render()) +} + +func (t *Table) SetHeaders(str ...string) { + row := table.Row{} + t.align = make([]text.Align, len(str)) + t.alignHeader = make([]text.Align, len(str)) + + for i, v := range str { + row = append(row, v) + t.align[i] = text.AlignLeft + t.alignHeader[i] = text.AlignCenter + } + + t.Writer.AppendHeader(row) +} + +func (t *Table) AddRow(str ...string) { + row := table.Row{} + for _, v := range str { + row = append(row, v) + } + + t.Writer.AppendRow(row) +} + +func (t *Table) SetRowLines(rowLines bool) { + t.Writer.Style().Options.SeparateRows = rowLines +} + +func (t *Table) SetAlignment(align ...text.Align) { + // align can be shorter than t.align, it will leave the default value + copy(t.align, align) +} + +func (t *Table) SetHeaderAlignment(align ...text.Align) { + copy(t.alignHeader, align) +} diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go index a3701c4dbbb..41db9e6cbf2 100644 --- a/cmd/crowdsec-cli/dashboard.go +++ b/cmd/crowdsec-cli/dashboard.go @@ -3,6 +3,7 @@ package main import ( + "errors" "fmt" "math" "os" @@ -19,15 +20,17 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/crowdsecurity/crowdsec/pkg/metabase" + "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/metabase" ) var ( metabaseUser = "crowdsec@crowdsec.net" metabasePassword string - metabaseDbPath string + metabaseDBPath string metabaseConfigPath string metabaseConfigFolder = "metabase/" metabaseConfigFile = "metabase.yaml" @@ -43,14 +46,17 @@ var ( // information needed to set up a random password on user's behalf ) -type cliDashboard struct{} +type cliDashboard struct { + cfg configGetter +} -func NewCLIDashboard() *cliDashboard { - return &cliDashboard{} +func NewCLIDashboard(cfg configGetter) *cliDashboard { + return &cliDashboard{ + cfg: cfg, + } } -func (cli cliDashboard) NewCommand() *cobra.Command { - /* ---- UPDATE COMMAND */ +func (cli *cliDashboard) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "dashboard [command]", Short: "Manage your metabase dashboard container [requires local API]", @@ -65,8 +71,13 @@ cscli dashboard start cscli dashboard stop cscli dashboard remove `, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if version.System == "docker" { + return errors.New("cscli dashboard is not supported whilst running CrowdSec within a container please see: https://github.com/crowdsecurity/example-docker-compose/tree/main/basic") + } + + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } @@ -74,13 +85,13 @@ cscli dashboard remove return err } - metabaseConfigFolderPath := filepath.Join(csConfig.ConfigPaths.ConfigDir, metabaseConfigFolder) + metabaseConfigFolderPath := filepath.Join(cfg.ConfigPaths.ConfigDir, metabaseConfigFolder) metabaseConfigPath = filepath.Join(metabaseConfigFolderPath, metabaseConfigFile) if err := os.MkdirAll(metabaseConfigFolderPath, os.ModePerm); err != nil { return err } - if err := require.DB(csConfig); err != nil { + if err := require.DB(cfg); err != nil { return err } @@ -95,20 +106,23 @@ cscli dashboard remove metabaseContainerID = oldContainerID } } + + log.Warn("cscli dashboard will be deprecated in version 1.7.0, read more at https://docs.crowdsec.net/blog/cscli_dashboard_deprecation/") + return nil }, } - cmd.AddCommand(cli.NewSetupCmd()) - cmd.AddCommand(cli.NewStartCmd()) - cmd.AddCommand(cli.NewStopCmd()) - cmd.AddCommand(cli.NewShowPasswordCmd()) - cmd.AddCommand(cli.NewRemoveCmd()) + cmd.AddCommand(cli.newSetupCmd()) + cmd.AddCommand(cli.newStartCmd()) + cmd.AddCommand(cli.newStopCmd()) + cmd.AddCommand(cli.newShowPasswordCmd()) + cmd.AddCommand(cli.newRemoveCmd()) return cmd } -func (cli cliDashboard) NewSetupCmd() *cobra.Command { +func (cli *cliDashboard) newSetupCmd() *cobra.Command { var force bool cmd := &cobra.Command{ @@ -122,15 +136,15 @@ cscli dashboard setup cscli dashboard setup --listen 0.0.0.0 cscli dashboard setup -l 0.0.0.0 -p 443 --password `, - RunE: func(cmd *cobra.Command, args []string) error { - if metabaseDbPath == "" { - metabaseDbPath = csConfig.ConfigPaths.DataDir + RunE: func(_ *cobra.Command, _ []string) error { + if metabaseDBPath == "" { + metabaseDBPath = cli.cfg().ConfigPaths.DataDir } if metabasePassword == "" { isValid := passwordIsValid(metabasePassword) for !isValid { - metabasePassword = generatePassword(16) + metabasePassword = idgen.GeneratePassword(16) isValid = passwordIsValid(metabasePassword) } } @@ -145,10 +159,10 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password if err != nil { return err } - if err = chownDatabase(dockerGroup.Gid); err != nil { + if err = cli.chownDatabase(dockerGroup.Gid); err != nil { return err } - mb, err := metabase.SetupMetabase(csConfig.API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDbPath, dockerGroup.Gid, metabaseContainerID, metabaseImage) + mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDBPath, dockerGroup.Gid, metabaseContainerID, metabaseImage) if err != nil { return err } @@ -161,29 +175,32 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password fmt.Printf("\tURL : '%s'\n", mb.Config.ListenURL) fmt.Printf("\tusername : '%s'\n", mb.Config.Username) fmt.Printf("\tpassword : '%s'\n", mb.Config.Password) + return nil }, } - cmd.Flags().BoolVarP(&force, "force", "f", false, "Force setup : override existing files") - cmd.Flags().StringVarP(&metabaseDbPath, "dir", "d", "", "Shared directory with metabase container") - cmd.Flags().StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container") - cmd.Flags().StringVar(&metabaseImage, "metabase-image", metabaseImage, "Metabase image to use") - cmd.Flags().StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container") - cmd.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") - //cmd.Flags().StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user") - cmd.Flags().StringVar(&metabasePassword, "password", "", "metabase password") + + flags := cmd.Flags() + flags.BoolVarP(&force, "force", "f", false, "Force setup : override existing files") + flags.StringVarP(&metabaseDBPath, "dir", "d", "", "Shared directory with metabase container") + flags.StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container") + flags.StringVar(&metabaseImage, "metabase-image", metabaseImage, "Metabase image to use") + flags.StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container") + flags.BoolVarP(&forceYes, "yes", "y", false, "force yes") + // flags.StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user") + flags.StringVar(&metabasePassword, "password", "", "metabase password") return cmd } -func (cli cliDashboard) NewStartCmd() *cobra.Command { +func (cli *cliDashboard) newStartCmd() *cobra.Command { cmd := &cobra.Command{ Use: "start", Short: "Start the metabase container.", Long: `Stats the metabase container using docker.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { mb, err := metabase.NewMetabase(metabaseConfigPath, metabaseContainerID) if err != nil { return err @@ -197,22 +214,24 @@ func (cli cliDashboard) NewStartCmd() *cobra.Command { } log.Infof("Started metabase") log.Infof("url : http://%s:%s", mb.Config.ListenAddr, mb.Config.ListenPort) + return nil }, } + cmd.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") return cmd } -func (cli cliDashboard) NewStopCmd() *cobra.Command { +func (cli *cliDashboard) newStopCmd() *cobra.Command { cmd := &cobra.Command{ Use: "stop", Short: "Stops the metabase container.", Long: `Stops the metabase container using docker.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { if err := metabase.StopContainer(metabaseContainerID); err != nil { return fmt.Errorf("unable to stop container '%s': %s", metabaseContainerID, err) } @@ -223,17 +242,18 @@ func (cli cliDashboard) NewStopCmd() *cobra.Command { return cmd } -func (cli cliDashboard) NewShowPasswordCmd() *cobra.Command { +func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command { cmd := &cobra.Command{Use: "show-password", Short: "displays password of metabase.", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { m := metabase.Metabase{} if err := m.LoadConfig(metabaseConfigPath); err != nil { return err } log.Printf("'%s'", m.Config.Password) + return nil }, } @@ -241,7 +261,7 @@ func (cli cliDashboard) NewShowPasswordCmd() *cobra.Command { return cmd } -func (cli cliDashboard) NewRemoveCmd() *cobra.Command { +func (cli *cliDashboard) newRemoveCmd() *cobra.Command { var force bool cmd := &cobra.Command{ @@ -254,7 +274,7 @@ func (cli cliDashboard) NewRemoveCmd() *cobra.Command { cscli dashboard remove cscli dashboard remove --force `, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { if !forceYes { var answer bool prompt := &survey.Confirm{ @@ -265,7 +285,7 @@ cscli dashboard remove --force return fmt.Errorf("unable to ask to force: %s", err) } if !answer { - return fmt.Errorf("user stated no to continue") + return errors.New("user stated no to continue") } } if metabase.IsContainerExist(metabaseContainerID) { @@ -277,7 +297,7 @@ cscli dashboard remove --force if err == nil { // if group exist, remove it groupDelCmd, err := exec.LookPath("groupdel") if err != nil { - return fmt.Errorf("unable to find 'groupdel' command, can't continue") + return errors.New("unable to find 'groupdel' command, can't continue") } groupDel := &exec.Cmd{Path: groupDelCmd, Args: []string{groupDelCmd, crowdsecGroup}} @@ -291,8 +311,8 @@ cscli dashboard remove --force } log.Infof("container %s stopped & removed", metabaseContainerID) } - log.Debugf("Removing metabase db %s", csConfig.ConfigPaths.DataDir) - if err := metabase.RemoveDatabase(csConfig.ConfigPaths.DataDir); err != nil { + log.Debugf("Removing metabase db %s", cli.cfg().ConfigPaths.DataDir) + if err := metabase.RemoveDatabase(cli.cfg().ConfigPaths.DataDir); err != nil { log.Warnf("failed to remove metabase internal db : %s", err) } if force { @@ -306,11 +326,14 @@ cscli dashboard remove --force } } } + return nil }, } - cmd.Flags().BoolVarP(&force, "force", "f", false, "Remove also the metabase image") - cmd.Flags().BoolVarP(&forceYes, "yes", "y", false, "force yes") + + flags := cmd.Flags() + flags.BoolVarP(&force, "force", "f", false, "Remove also the metabase image") + flags.BoolVarP(&forceYes, "yes", "y", false, "force yes") return cmd } @@ -351,7 +374,7 @@ func checkSystemMemory(forceYes *bool) error { } if !answer { - return fmt.Errorf("user stated no to continue") + return errors.New("user stated no to continue") } return nil @@ -384,7 +407,7 @@ func disclaimer(forceYes *bool) error { } if !answer { - return fmt.Errorf("user stated no to responsibilities") + return errors.New("user stated no to responsibilities") } return nil @@ -420,7 +443,7 @@ func checkGroups(forceYes *bool) (*user.Group, error) { groupAddCmd, err := exec.LookPath("groupadd") if err != nil { - return dockerGroup, fmt.Errorf("unable to find 'groupadd' command, can't continue") + return dockerGroup, errors.New("unable to find 'groupadd' command, can't continue") } groupAdd := &exec.Cmd{Path: groupAddCmd, Args: []string{groupAddCmd, crowdsecGroup}} @@ -431,22 +454,24 @@ func checkGroups(forceYes *bool) (*user.Group, error) { return user.LookupGroup(crowdsecGroup) } -func chownDatabase(gid string) error { +func (cli *cliDashboard) chownDatabase(gid string) error { + cfg := cli.cfg() intID, err := strconv.Atoi(gid) + if err != nil { return fmt.Errorf("unable to convert group ID to int: %s", err) } - if stat, err := os.Stat(csConfig.DbConfig.DbPath); !os.IsNotExist(err) { + if stat, err := os.Stat(cfg.DbConfig.DbPath); !os.IsNotExist(err) { info := stat.Sys() - if err := os.Chown(csConfig.DbConfig.DbPath, int(info.(*syscall.Stat_t).Uid), intID); err != nil { - return fmt.Errorf("unable to chown sqlite db file '%s': %s", csConfig.DbConfig.DbPath, err) + if err := os.Chown(cfg.DbConfig.DbPath, int(info.(*syscall.Stat_t).Uid), intID); err != nil { + return fmt.Errorf("unable to chown sqlite db file '%s': %s", cfg.DbConfig.DbPath, err) } } - if csConfig.DbConfig.Type == "sqlite" && csConfig.DbConfig.UseWal != nil && *csConfig.DbConfig.UseWal { + if cfg.DbConfig.Type == "sqlite" && cfg.DbConfig.UseWal != nil && *cfg.DbConfig.UseWal { for _, ext := range []string{"-wal", "-shm"} { - file := csConfig.DbConfig.DbPath + ext + file := cfg.DbConfig.DbPath + ext if stat, err := os.Stat(file); !os.IsNotExist(err) { info := stat.Sys() if err := os.Chown(file, int(info.(*syscall.Stat_t).Uid), intID); err != nil { diff --git a/cmd/crowdsec-cli/dashboard_unsupported.go b/cmd/crowdsec-cli/dashboard_unsupported.go index 072ff525b19..cc80abd2528 100644 --- a/cmd/crowdsec-cli/dashboard_unsupported.go +++ b/cmd/crowdsec-cli/dashboard_unsupported.go @@ -9,17 +9,21 @@ import ( "github.com/spf13/cobra" ) -type cliDashboard struct{} +type cliDashboard struct{ + cfg configGetter +} -func NewCLIDashboard() *cliDashboard { - return &cliDashboard{} +func NewCLIDashboard(cfg configGetter) *cliDashboard { + return &cliDashboard{ + cfg: cfg, + } } func (cli cliDashboard) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "dashboard", DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { + Run: func(_ *cobra.Command, _ []string) { log.Infof("Dashboard command is disabled on %s", runtime.GOOS) }, } diff --git a/cmd/crowdsec-cli/decisions.go b/cmd/crowdsec-cli/decisions.go deleted file mode 100644 index 683f100d4f7..00000000000 --- a/cmd/crowdsec-cli/decisions.go +++ /dev/null @@ -1,510 +0,0 @@ -package main - -import ( - "context" - "encoding/csv" - "encoding/json" - "fmt" - "net/url" - "os" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/go-cs-lib/version" - - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -var Client *apiclient.ApiClient - -func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error { - /*here we cheat a bit : to make it more readable for the user, we dedup some entries*/ - spamLimit := make(map[string]bool) - skipped := 0 - - for aIdx := 0; aIdx < len(*alerts); aIdx++ { - alertItem := (*alerts)[aIdx] - newDecisions := make([]*models.Decision, 0) - - for _, decisionItem := range alertItem.Decisions { - spamKey := fmt.Sprintf("%t:%s:%s:%s", *decisionItem.Simulated, *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) - if _, ok := spamLimit[spamKey]; ok { - skipped++ - continue - } - - spamLimit[spamKey] = true - - newDecisions = append(newDecisions, decisionItem) - } - - alertItem.Decisions = newDecisions - } - - if csConfig.Cscli.Output == "raw" { - csvwriter := csv.NewWriter(os.Stdout) - header := []string{"id", "source", "ip", "reason", "action", "country", "as", "events_count", "expiration", "simulated", "alert_id"} - - if printMachine { - header = append(header, "machine") - } - - err := csvwriter.Write(header) - if err != nil { - return err - } - - for _, alertItem := range *alerts { - for _, decisionItem := range alertItem.Decisions { - raw := []string{ - fmt.Sprintf("%d", decisionItem.ID), - *decisionItem.Origin, - *decisionItem.Scope + ":" + *decisionItem.Value, - *decisionItem.Scenario, - *decisionItem.Type, - alertItem.Source.Cn, - alertItem.Source.GetAsNumberName(), - fmt.Sprintf("%d", *alertItem.EventsCount), - *decisionItem.Duration, - fmt.Sprintf("%t", *decisionItem.Simulated), - fmt.Sprintf("%d", alertItem.ID), - } - if printMachine { - raw = append(raw, alertItem.MachineID) - } - - err := csvwriter.Write(raw) - if err != nil { - return err - } - } - } - - csvwriter.Flush() - } else if csConfig.Cscli.Output == "json" { - if *alerts == nil { - // avoid returning "null" in `json" - // could be cleaner if we used slice of alerts directly - fmt.Println("[]") - return nil - } - x, _ := json.MarshalIndent(alerts, "", " ") - fmt.Printf("%s", string(x)) - } else if csConfig.Cscli.Output == "human" { - if len(*alerts) == 0 { - fmt.Println("No active decisions") - return nil - } - decisionsTable(color.Output, alerts, printMachine) - if skipped > 0 { - fmt.Printf("%d duplicated entries skipped\n", skipped) - } - } - - return nil -} - - -type cliDecisions struct {} - -func NewCLIDecisions() *cliDecisions { - return &cliDecisions{} -} - -func (cli cliDecisions) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "decisions [action]", - Short: "Manage decisions", - Long: `Add/List/Delete/Import decisions from LAPI`, - Example: `cscli decisions [action] [filter]`, - Aliases: []string{"decision"}, - /*TBD example*/ - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - if err := csConfig.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - password := strfmt.Password(csConfig.API.Client.Credentials.Password) - apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url %s: %w", csConfig.API.Client.Credentials.URL, err) - } - Client, err = apiclient.NewClient(&apiclient.Config{ - MachineID: csConfig.API.Client.Credentials.Login, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: "v1", - }) - if err != nil { - return fmt.Errorf("creating api client: %w", err) - } - return nil - }, - } - - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewAddCmd()) - cmd.AddCommand(cli.NewDeleteCmd()) - cmd.AddCommand(cli.NewImportCmd()) - - return cmd -} - -func (cli cliDecisions) NewListCmd() *cobra.Command { - var filter = apiclient.AlertsListOpts{ - ValueEquals: new(string), - ScopeEquals: new(string), - ScenarioEquals: new(string), - OriginEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - Since: new(string), - Until: new(string), - TypeEquals: new(string), - IncludeCAPI: new(bool), - Limit: new(int), - } - - NoSimu := new(bool) - contained := new(bool) - - var printMachine bool - - cmd := &cobra.Command{ - Use: "list [options]", - Short: "List decisions from LAPI", - Example: `cscli decisions list -i 1.2.3.4 -cscli decisions list -r 1.2.3.0/24 -cscli decisions list -s crowdsecurity/ssh-bf -cscli decisions list -t ban -`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, _ []string) error { - var err error - /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(filter.IPEquals, filter.RangeEquals, filter.ScopeEquals, filter.ValueEquals); err != nil { - return err - } - filter.ActiveDecisionEquals = new(bool) - *filter.ActiveDecisionEquals = true - if NoSimu != nil && *NoSimu { - filter.IncludeSimulated = new(bool) - } - /* nullify the empty entries to avoid bad filter */ - if *filter.Until == "" { - filter.Until = nil - } else if strings.HasSuffix(*filter.Until, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*filter.Until, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Until) - } - *filter.Until = fmt.Sprintf("%d%s", days*24, "h") - } - - if *filter.Since == "" { - filter.Since = nil - } else if strings.HasSuffix(*filter.Since, "d") { - /*time.ParseDuration support hours 'h' as bigger unit, let's make the user's life easier*/ - realDuration := strings.TrimSuffix(*filter.Since, "d") - days, err := strconv.Atoi(realDuration) - if err != nil { - printHelp(cmd) - return fmt.Errorf("can't parse duration %s, valid durations format: 1d, 4h, 4h15m", *filter.Since) - } - *filter.Since = fmt.Sprintf("%d%s", days*24, "h") - } - if *filter.IncludeCAPI { - *filter.Limit = 0 - } - if *filter.TypeEquals == "" { - filter.TypeEquals = nil - } - if *filter.ValueEquals == "" { - filter.ValueEquals = nil - } - if *filter.ScopeEquals == "" { - filter.ScopeEquals = nil - } - if *filter.ScenarioEquals == "" { - filter.ScenarioEquals = nil - } - if *filter.IPEquals == "" { - filter.IPEquals = nil - } - if *filter.RangeEquals == "" { - filter.RangeEquals = nil - } - - if *filter.OriginEquals == "" { - filter.OriginEquals = nil - } - - if contained != nil && *contained { - filter.Contains = new(bool) - } - - alerts, _, err := Client.Alerts.List(context.Background(), filter) - if err != nil { - return fmt.Errorf("unable to retrieve decisions: %w", err) - } - - err = DecisionsToTable(alerts, printMachine) - if err != nil { - return fmt.Errorf("unable to print decisions: %w", err) - } - - return nil - }, - } - cmd.Flags().SortFlags = false - cmd.Flags().BoolVarP(filter.IncludeCAPI, "all", "a", false, "Include decisions from Central API") - cmd.Flags().StringVar(filter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)") - cmd.Flags().StringVar(filter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)") - cmd.Flags().StringVarP(filter.TypeEquals, "type", "t", "", "restrict to this decision type (ie. ban,captcha)") - cmd.Flags().StringVar(filter.ScopeEquals, "scope", "", "restrict to this scope (ie. ip,range,session)") - cmd.Flags().StringVar(filter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - cmd.Flags().StringVarP(filter.ValueEquals, "value", "v", "", "restrict to this value (ie. 1.2.3.4,userName)") - cmd.Flags().StringVarP(filter.ScenarioEquals, "scenario", "s", "", "restrict to this scenario (ie. crowdsecurity/ssh-bf)") - cmd.Flags().StringVarP(filter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(filter.RangeEquals, "range", "r", "", "restrict to alerts from this source range (shorthand for --scope range --value )") - cmd.Flags().IntVarP(filter.Limit, "limit", "l", 100, "number of alerts to get (use 0 to remove the limit)") - cmd.Flags().BoolVar(NoSimu, "no-simu", false, "exclude decisions in simulation mode") - cmd.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that triggered decisions") - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - - return cmd -} - -func (cli cliDecisions) NewAddCmd() *cobra.Command { - var ( - addIP string - addRange string - addDuration string - addValue string - addScope string - addReason string - addType string - ) - - cmd := &cobra.Command{ - Use: "add [options]", - Short: "Add decision to LAPI", - Example: `cscli decisions add --ip 1.2.3.4 -cscli decisions add --range 1.2.3.0/24 -cscli decisions add --ip 1.2.3.4 --duration 24h --type captcha -cscli decisions add --scope username --value foobar -`, - /*TBD : fix long and example*/ - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, _ []string) error { - var err error - alerts := models.AddAlertsRequest{} - origin := types.CscliOrigin - capacity := int32(0) - leakSpeed := "0" - eventsCount := int32(1) - empty := "" - simulated := false - startAt := time.Now().UTC().Format(time.RFC3339) - stopAt := time.Now().UTC().Format(time.RFC3339) - createdAt := time.Now().UTC().Format(time.RFC3339) - - /*take care of shorthand options*/ - if err := manageCliDecisionAlerts(&addIP, &addRange, &addScope, &addValue); err != nil { - return err - } - - if addIP != "" { - addValue = addIP - addScope = types.Ip - } else if addRange != "" { - addValue = addRange - addScope = types.Range - } else if addValue == "" { - printHelp(cmd) - return fmt.Errorf("missing arguments, a value is required (--ip, --range or --scope and --value)") - } - - if addReason == "" { - addReason = fmt.Sprintf("manual '%s' from '%s'", addType, csConfig.API.Client.Credentials.Login) - } - decision := models.Decision{ - Duration: &addDuration, - Scope: &addScope, - Value: &addValue, - Type: &addType, - Scenario: &addReason, - Origin: &origin, - } - alert := models.Alert{ - Capacity: &capacity, - Decisions: []*models.Decision{&decision}, - Events: []*models.Event{}, - EventsCount: &eventsCount, - Leakspeed: &leakSpeed, - Message: &addReason, - ScenarioHash: &empty, - Scenario: &addReason, - ScenarioVersion: &empty, - Simulated: &simulated, - //setting empty scope/value broke plugins, and it didn't seem to be needed anymore w/ latest papi changes - Source: &models.Source{ - AsName: empty, - AsNumber: empty, - Cn: empty, - IP: addValue, - Range: "", - Scope: &addScope, - Value: &addValue, - }, - StartAt: &startAt, - StopAt: &stopAt, - CreatedAt: createdAt, - } - alerts = append(alerts, &alert) - - _, _, err = Client.Alerts.Add(context.Background(), alerts) - if err != nil { - return err - } - - log.Info("Decision successfully added") - return nil - }, - } - - cmd.Flags().SortFlags = false - cmd.Flags().StringVarP(&addIP, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(&addRange, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmd.Flags().StringVarP(&addDuration, "duration", "d", "4h", "Decision duration (ie. 1h,4h,30m)") - cmd.Flags().StringVarP(&addValue, "value", "v", "", "The value (ie. --scope username --value foobar)") - cmd.Flags().StringVar(&addScope, "scope", types.Ip, "Decision scope (ie. ip,range,username)") - cmd.Flags().StringVarP(&addReason, "reason", "R", "", "Decision reason (ie. scenario-name)") - cmd.Flags().StringVarP(&addType, "type", "t", "ban", "Decision type (ie. ban,captcha,throttle)") - - return cmd -} - -func (cli cliDecisions) NewDeleteCmd() *cobra.Command { - var delFilter = apiclient.DecisionsDeleteOpts{ - ScopeEquals: new(string), - ValueEquals: new(string), - TypeEquals: new(string), - IPEquals: new(string), - RangeEquals: new(string), - ScenarioEquals: new(string), - OriginEquals: new(string), - } - - var delDecisionID string - - var delDecisionAll bool - - contained := new(bool) - - cmd := &cobra.Command{ - Use: "delete [options]", - Short: "Delete decisions", - DisableAutoGenTag: true, - Aliases: []string{"remove"}, - Example: `cscli decisions delete -r 1.2.3.0/24 -cscli decisions delete -i 1.2.3.4 -cscli decisions delete --id 42 -cscli decisions delete --type captcha -`, - /*TBD : refaire le Long/Example*/ - PreRunE: func(cmd *cobra.Command, _ []string) error { - if delDecisionAll { - return nil - } - if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && - *delFilter.TypeEquals == "" && *delFilter.IPEquals == "" && - *delFilter.RangeEquals == "" && *delFilter.ScenarioEquals == "" && - *delFilter.OriginEquals == "" && delDecisionID == "" { - cmd.Usage() - return fmt.Errorf("at least one filter or --all must be specified") - } - - return nil - }, - RunE: func(_ *cobra.Command, _ []string) error { - var err error - var decisions *models.DeleteDecisionResponse - - /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(delFilter.IPEquals, delFilter.RangeEquals, delFilter.ScopeEquals, delFilter.ValueEquals); err != nil { - return err - } - if *delFilter.ScopeEquals == "" { - delFilter.ScopeEquals = nil - } - if *delFilter.OriginEquals == "" { - delFilter.OriginEquals = nil - } - if *delFilter.ValueEquals == "" { - delFilter.ValueEquals = nil - } - if *delFilter.ScenarioEquals == "" { - delFilter.ScenarioEquals = nil - } - if *delFilter.TypeEquals == "" { - delFilter.TypeEquals = nil - } - if *delFilter.IPEquals == "" { - delFilter.IPEquals = nil - } - if *delFilter.RangeEquals == "" { - delFilter.RangeEquals = nil - } - if contained != nil && *contained { - delFilter.Contains = new(bool) - } - - if delDecisionID == "" { - decisions, _, err = Client.Decisions.Delete(context.Background(), delFilter) - if err != nil { - return fmt.Errorf("unable to delete decisions: %v", err) - } - } else { - if _, err = strconv.Atoi(delDecisionID); err != nil { - return fmt.Errorf("id '%s' is not an integer: %v", delDecisionID, err) - } - decisions, _, err = Client.Decisions.DeleteOne(context.Background(), delDecisionID) - if err != nil { - return fmt.Errorf("unable to delete decision: %v", err) - } - } - log.Infof("%s decision(s) deleted", decisions.NbDeleted) - return nil - }, - } - - cmd.Flags().SortFlags = false - cmd.Flags().StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - cmd.Flags().StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") - cmd.Flags().StringVarP(delFilter.TypeEquals, "type", "t", "", "the decision type (ie. ban,captcha)") - cmd.Flags().StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - cmd.Flags().StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario name (ie. crowdsecurity/ssh-bf)") - cmd.Flags().StringVar(delFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ","))) - - cmd.Flags().StringVar(&delDecisionID, "id", "", "decision id") - cmd.Flags().BoolVar(&delDecisionAll, "all", false, "delete all decisions") - cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range") - - return cmd -} diff --git a/cmd/crowdsec-cli/doc.go b/cmd/crowdsec-cli/doc.go index a4896f3da30..f68d535db03 100644 --- a/cmd/crowdsec-cli/doc.go +++ b/cmd/crowdsec-cli/doc.go @@ -16,20 +16,30 @@ func NewCLIDoc() *cliDoc { } func (cli cliDoc) NewCommand(rootCmd *cobra.Command) *cobra.Command { + var target string + + const defaultTarget = "./doc" + cmd := &cobra.Command{ Use: "doc", - Short: "Generate the documentation in `./doc/`. Directory must exist.", - Args: cobra.ExactArgs(0), + Short: "Generate the documentation related to cscli commands. Target directory must exist.", + Args: cobra.NoArgs, Hidden: true, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - if err := doc.GenMarkdownTreeCustom(rootCmd, "./doc/", cli.filePrepender, cli.linkHandler); err != nil { - return fmt.Errorf("failed to generate cobra doc: %s", err) + RunE: func(_ *cobra.Command, args []string) error { + if err := doc.GenMarkdownTreeCustom(rootCmd, target, cli.filePrepender, cli.linkHandler); err != nil { + return fmt.Errorf("failed to generate cscli documentation: %w", err) } + + fmt.Println("Documentation generated in", target) + return nil }, } + flags := cmd.Flags() + flags.StringVar(&target, "target", defaultTarget, "The target directory where the documentation will be generated") + return cmd } @@ -39,8 +49,10 @@ id: %s title: %s --- ` + name := filepath.Base(filename) base := strings.TrimSuffix(name, filepath.Ext(name)) + return fmt.Sprintf(header, base, strings.ReplaceAll(base, "_", " ")) } diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go deleted file mode 100644 index 1860540e7dc..00000000000 --- a/cmd/crowdsec-cli/hubtest.go +++ /dev/null @@ -1,692 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "math" - "os" - "path/filepath" - "strings" - "text/template" - - "github.com/AlecAivazis/survey/v2" - "github.com/enescakir/emoji" - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - - "github.com/crowdsecurity/crowdsec/pkg/dumps" - "github.com/crowdsecurity/crowdsec/pkg/hubtest" -) - -var HubTest hubtest.HubTest -var HubAppsecTests hubtest.HubTest -var hubPtr *hubtest.HubTest -var isAppsecTest bool - -type cliHubTest struct{} - -func NewCLIHubTest() *cliHubTest { - return &cliHubTest{} -} - -func (cli cliHubTest) NewCommand() *cobra.Command { - var hubPath string - var crowdsecPath string - var cscliPath string - - cmd := &cobra.Command{ - Use: "hubtest", - Short: "Run functional tests on hub configurations", - Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - var err error - HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, false) - if err != nil { - return fmt.Errorf("unable to load hubtest: %+v", err) - } - - HubAppsecTests, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, true) - if err != nil { - return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) - } - /*commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests*/ - hubPtr = &HubTest - if isAppsecTest { - hubPtr = &HubAppsecTests - } - return nil - }, - } - - cmd.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") - cmd.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") - cmd.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") - cmd.PersistentFlags().BoolVar(&isAppsecTest, "appsec", false, "Command relates to appsec tests") - - cmd.AddCommand(cli.NewCreateCmd()) - cmd.AddCommand(cli.NewRunCmd()) - cmd.AddCommand(cli.NewCleanCmd()) - cmd.AddCommand(cli.NewInfoCmd()) - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewCoverageCmd()) - cmd.AddCommand(cli.NewEvalCmd()) - cmd.AddCommand(cli.NewExplainCmd()) - - return cmd -} - -func (cli cliHubTest) NewCreateCmd() *cobra.Command { - parsers := []string{} - postoverflows := []string{} - scenarios := []string{} - var ignoreParsers bool - var labels map[string]string - var logType string - - cmd := &cobra.Command{ - Use: "create", - Short: "create [test_name]", - Example: `cscli hubtest create my-awesome-test --type syslog -cscli hubtest create my-nginx-custom-test --type nginx -cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - testName := args[0] - testPath := filepath.Join(hubPtr.HubTestPath, testName) - if _, err := os.Stat(testPath); os.IsExist(err) { - return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath) - } - - if isAppsecTest { - logType = "appsec" - } - - if logType == "" { - return fmt.Errorf("please provide a type (--type) for the test") - } - - if err := os.MkdirAll(testPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", testPath, err) - } - - configFilePath := filepath.Join(testPath, "config.yaml") - - configFileData := &hubtest.HubTestItemConfig{} - if logType == "appsec" { - //create empty nuclei template file - nucleiFileName := fmt.Sprintf("%s.yaml", testName) - nucleiFilePath := filepath.Join(testPath, nucleiFileName) - nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0755) - if err != nil { - return err - } - - ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) - if ntpl == nil { - return fmt.Errorf("unable to parse nuclei template") - } - ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) - nucleiFile.Close() - configFileData.AppsecRules = []string{"./appsec-rules//your_rule_here.yaml"} - configFileData.NucleiTemplate = nucleiFileName - fmt.Println() - fmt.Printf(" Test name : %s\n", testName) - fmt.Printf(" Test path : %s\n", testPath) - fmt.Printf(" Config File : %s\n", configFilePath) - fmt.Printf(" Nuclei Template : %s\n", nucleiFilePath) - } else { - // create empty log file - logFileName := fmt.Sprintf("%s.log", testName) - logFilePath := filepath.Join(testPath, logFileName) - logFile, err := os.Create(logFilePath) - if err != nil { - return err - } - logFile.Close() - - // create empty parser assertion file - parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName) - parserAssertFile, err := os.Create(parserAssertFilePath) - if err != nil { - return err - } - parserAssertFile.Close() - // create empty scenario assertion file - scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName) - scenarioAssertFile, err := os.Create(scenarioAssertFilePath) - if err != nil { - return err - } - scenarioAssertFile.Close() - - parsers = append(parsers, "crowdsecurity/syslog-logs") - parsers = append(parsers, "crowdsecurity/dateparse-enrich") - - if len(scenarios) == 0 { - scenarios = append(scenarios, "") - } - - if len(postoverflows) == 0 { - postoverflows = append(postoverflows, "") - } - configFileData.Parsers = parsers - configFileData.Scenarios = scenarios - configFileData.PostOverflows = postoverflows - configFileData.LogFile = logFileName - configFileData.LogType = logType - configFileData.IgnoreParsers = ignoreParsers - configFileData.Labels = labels - fmt.Println() - fmt.Printf(" Test name : %s\n", testName) - fmt.Printf(" Test path : %s\n", testPath) - fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) - fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) - fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) - fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) - - } - - fd, err := os.Create(configFilePath) - if err != nil { - return fmt.Errorf("open: %s", err) - } - data, err := yaml.Marshal(configFileData) - if err != nil { - return fmt.Errorf("marshal: %s", err) - } - _, err = fd.Write(data) - if err != nil { - return fmt.Errorf("write: %s", err) - } - if err := fd.Close(); err != nil { - return fmt.Errorf("close: %s", err) - } - return nil - }, - } - - cmd.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") - cmd.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") - cmd.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") - cmd.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") - cmd.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") - - return cmd -} - -func (cli cliHubTest) NewRunCmd() *cobra.Command { - var noClean bool - var runAll bool - var forceClean bool - var NucleiTargetHost string - var AppSecHost string - var cmd = &cobra.Command{ - Use: "run", - Short: "run [test_name]", - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if !runAll && len(args) == 0 { - printHelp(cmd) - return fmt.Errorf("please provide test to run or --all flag") - } - hubPtr.NucleiTargetHost = NucleiTargetHost - hubPtr.AppSecHost = AppSecHost - if runAll { - if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - } else { - for _, testName := range args { - _, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - } - } - - // set timezone to avoid DST issues - os.Setenv("TZ", "UTC") - for _, test := range hubPtr.Tests { - if csConfig.Cscli.Output == "human" { - log.Infof("Running test '%s'", test.Name) - } - err := test.Run() - if err != nil { - log.Errorf("running test '%s' failed: %+v", test.Name, err) - } - } - - return nil - }, - PersistentPostRunE: func(_ *cobra.Command, _ []string) error { - success := true - testResult := make(map[string]bool) - for _, test := range hubPtr.Tests { - if test.AutoGen && !isAppsecTest { - if test.ParserAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) - fmt.Println() - fmt.Println(test.ParserAssert.AutoGenAssertData) - } - if test.ScenarioAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) - fmt.Println() - fmt.Println(test.ScenarioAssert.AutoGenAssertData) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - fmt.Printf("\nPlease fill your assert file(s) for test '%s', exiting\n", test.Name) - os.Exit(1) - } - testResult[test.Name] = test.Success - if test.Success { - if csConfig.Cscli.Output == "human" { - log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - } else { - success = false - cleanTestEnv := false - if csConfig.Cscli.Output == "human" { - if len(test.ParserAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) - for _, fail := range test.ParserAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if len(test.ScenarioAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) - for _, fail := range test.ScenarioAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if !forceClean && !noClean { - prompt := &survey.Confirm{ - Message: fmt.Sprintf("\nDo you want to remove runtime folder for test '%s'? (default: Yes)", test.Name), - Default: true, - } - if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { - return fmt.Errorf("unable to ask to remove runtime folder: %s", err) - } - } - } - - if cleanTestEnv || forceClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - } - } - - switch csConfig.Cscli.Output { - case "human": - hubTestResultTable(color.Output, testResult) - case "json": - jsonResult := make(map[string][]string, 0) - jsonResult["success"] = make([]string, 0) - jsonResult["fail"] = make([]string, 0) - for testName, success := range testResult { - if success { - jsonResult["success"] = append(jsonResult["success"], testName) - } else { - jsonResult["fail"] = append(jsonResult["fail"], testName) - } - } - jsonStr, err := json.Marshal(jsonResult) - if err != nil { - return fmt.Errorf("unable to json test result: %s", err) - } - fmt.Println(string(jsonStr)) - default: - return fmt.Errorf("only human/json output modes are supported") - } - - if !success { - os.Exit(1) - } - - return nil - }, - } - - cmd.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") - cmd.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") - cmd.Flags().StringVar(&NucleiTargetHost, "target", hubtest.DefaultNucleiTarget, "Target for AppSec Test") - cmd.Flags().StringVar(&AppSecHost, "host", hubtest.DefaultAppsecHost, "Address to expose AppSec for hubtest") - cmd.Flags().BoolVar(&runAll, "all", false, "Run all tests") - - return cmd -} - -func (cli cliHubTest) NewCleanCmd() *cobra.Command { - var cmd = &cobra.Command{ - Use: "clean", - Short: "clean [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) - } - } - - return nil - }, - } - - return cmd -} - -func (cli cliHubTest) NewInfoCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "info", - Short: "info [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) - } - fmt.Println() - fmt.Printf(" Test name : %s\n", test.Name) - fmt.Printf(" Test path : %s\n", test.Path) - if isAppsecTest { - fmt.Printf(" Nuclei Template : %s\n", test.Config.NucleiTemplate) - fmt.Printf(" Appsec Rules : %s\n", strings.Join(test.Config.AppsecRules, ", ")) - } else { - fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) - fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) - fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) - } - fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) - } - - return nil - }, - } - - return cmd -} - -func (cli cliHubTest) NewListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %s", err) - } - - switch csConfig.Cscli.Output { - case "human": - hubTestListTable(color.Output, hubPtr.Tests) - case "json": - j, err := json.MarshalIndent(hubPtr.Tests, " ", " ") - if err != nil { - return err - } - fmt.Println(string(j)) - default: - return fmt.Errorf("only human/json output modes are supported") - } - - return nil - }, - } - - return cmd -} - -func (cli cliHubTest) NewCoverageCmd() *cobra.Command { - var showParserCov bool - var showScenarioCov bool - var showOnlyPercent bool - var showAppsecCov bool - - cmd := &cobra.Command{ - Use: "coverage", - Short: "coverage", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - //for this one we explicitly don't do for appsec - if err := HubTest.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - var err error - scenarioCoverage := []hubtest.Coverage{} - parserCoverage := []hubtest.Coverage{} - appsecRuleCoverage := []hubtest.Coverage{} - scenarioCoveragePercent := 0 - parserCoveragePercent := 0 - appsecRuleCoveragePercent := 0 - - // if both are false (flag by default), show both - showAll := !showScenarioCov && !showParserCov && !showAppsecCov - - if showParserCov || showAll { - parserCoverage, err = HubTest.GetParsersCoverage() - if err != nil { - return fmt.Errorf("while getting parser coverage: %s", err) - } - parserTested := 0 - for _, test := range parserCoverage { - if test.TestsCount > 0 { - parserTested++ - } - } - parserCoveragePercent = int(math.Round((float64(parserTested) / float64(len(parserCoverage)) * 100))) - } - - if showScenarioCov || showAll { - scenarioCoverage, err = HubTest.GetScenariosCoverage() - if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) - } - - scenarioTested := 0 - for _, test := range scenarioCoverage { - if test.TestsCount > 0 { - scenarioTested++ - } - } - - scenarioCoveragePercent = int(math.Round((float64(scenarioTested) / float64(len(scenarioCoverage)) * 100))) - } - - if showAppsecCov || showAll { - appsecRuleCoverage, err = HubTest.GetAppsecCoverage() - if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) - } - - appsecRuleTested := 0 - for _, test := range appsecRuleCoverage { - if test.TestsCount > 0 { - appsecRuleTested++ - } - } - appsecRuleCoveragePercent = int(math.Round((float64(appsecRuleTested) / float64(len(appsecRuleCoverage)) * 100))) - } - - if showOnlyPercent { - if showAll { - fmt.Printf("parsers=%d%%\nscenarios=%d%%\nappsec_rules=%d%%", parserCoveragePercent, scenarioCoveragePercent, appsecRuleCoveragePercent) - } else if showParserCov { - fmt.Printf("parsers=%d%%", parserCoveragePercent) - } else if showScenarioCov { - fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) - } else if showAppsecCov { - fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) - } - os.Exit(0) - } - - switch csConfig.Cscli.Output { - case "human": - if showParserCov || showAll { - hubTestParserCoverageTable(color.Output, parserCoverage) - } - - if showScenarioCov || showAll { - hubTestScenarioCoverageTable(color.Output, scenarioCoverage) - } - - if showAppsecCov || showAll { - hubTestAppsecRuleCoverageTable(color.Output, appsecRuleCoverage) - } - - fmt.Println() - if showParserCov || showAll { - fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent) - } - if showScenarioCov || showAll { - fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent) - } - if showAppsecCov || showAll { - fmt.Printf("APPSEC RULES : %d%% of coverage\n", appsecRuleCoveragePercent) - } - case "json": - dump, err := json.MarshalIndent(parserCoverage, "", " ") - if err != nil { - return err - } - fmt.Printf("%s", dump) - dump, err = json.MarshalIndent(scenarioCoverage, "", " ") - if err != nil { - return err - } - fmt.Printf("%s", dump) - dump, err = json.MarshalIndent(appsecRuleCoverage, "", " ") - if err != nil { - return err - } - fmt.Printf("%s", dump) - default: - return fmt.Errorf("only human/json output modes are supported") - } - - return nil - }, - } - - cmd.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage") - cmd.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage") - cmd.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage") - cmd.PersistentFlags().BoolVar(&showAppsecCov, "appsec", false, "Show only appsec coverage") - - return cmd -} - -func (cli cliHubTest) NewEvalCmd() *cobra.Command { - var evalExpression string - - cmd := &cobra.Command{ - Use: "eval", - Short: "eval [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) - } - - output, err := test.ParserAssert.EvalExpression(evalExpression) - if err != nil { - return err - } - - fmt.Print(output) - } - - return nil - }, - } - - cmd.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") - - return cmd -} - -func (cli cliHubTest) NewExplainCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "explain", - Short: "explain [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - if err = test.Run(); err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - - if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { - return fmt.Errorf("unable to load parser result after run: %s", err) - } - } - - err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) - if err != nil { - if err = test.Run(); err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - - if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { - return fmt.Errorf("unable to load scenario result after run: %s", err) - } - } - opts := dumps.DumpOpts{} - dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) - } - - return nil - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/hubtest_table.go b/cmd/crowdsec-cli/hubtest_table.go deleted file mode 100644 index 4034da7e519..00000000000 --- a/cmd/crowdsec-cli/hubtest_table.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/hubtest" -) - -func hubTestResultTable(out io.Writer, testResult map[string]bool) { - t := newLightTable(out) - t.SetHeaders("Test", "Result") - t.SetHeaderAlignment(table.AlignLeft) - t.SetAlignment(table.AlignLeft) - - for testName, success := range testResult { - status := emoji.CheckMarkButton.String() - if !success { - status = emoji.CrossMark.String() - } - - t.AddRow(testName, status) - } - - t.Render() -} - -func hubTestListTable(out io.Writer, tests []*hubtest.HubTestItem) { - t := newLightTable(out) - t.SetHeaders("Name", "Path") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft) - - for _, test := range tests { - t.AddRow(test.Name, test.Path) - } - - t.Render() -} - -func hubTestParserCoverageTable(out io.Writer, coverage []hubtest.Coverage) { - t := newLightTable(out) - t.SetHeaders("Parser", "Status", "Number of tests") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - parserTested := 0 - - for _, test := range coverage { - status := emoji.RedCircle.String() - if test.TestsCount > 0 { - status = emoji.GreenCircle.String() - parserTested++ - } - t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} - -func hubTestAppsecRuleCoverageTable(out io.Writer, coverage []hubtest.Coverage) { - t := newLightTable(out) - t.SetHeaders("Appsec Rule", "Status", "Number of tests") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - parserTested := 0 - - for _, test := range coverage { - status := emoji.RedCircle.String() - if test.TestsCount > 0 { - status = emoji.GreenCircle.String() - parserTested++ - } - t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} - -func hubTestScenarioCoverageTable(out io.Writer, coverage []hubtest.Coverage) { - t := newLightTable(out) - t.SetHeaders("Scenario", "Status", "Number of tests") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - parserTested := 0 - - for _, test := range coverage { - status := emoji.RedCircle.String() - if test.TestsCount > 0 { - status = emoji.GreenCircle.String() - parserTested++ - } - t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/idgen/machineid.go b/cmd/crowdsec-cli/idgen/machineid.go new file mode 100644 index 00000000000..4bd356b3abc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/machineid.go @@ -0,0 +1,48 @@ +package idgen + +import ( + "fmt" + "strings" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/machineid" +) + +// Returns a unique identifier for each crowdsec installation, using an +// identifier of the OS installation where available, otherwise a random +// string. +func generateMachineIDPrefix() (string, error) { + prefix, err := machineid.ID() + if err == nil { + return prefix, nil + } + + log.Debugf("failed to get machine-id with usual files: %s", err) + + bID, err := uuid.NewRandom() + if err == nil { + return bID.String(), nil + } + + return "", fmt.Errorf("generating machine id: %w", err) +} + +// Generate a unique identifier, composed by a prefix and a random suffix. +// The prefix can be provided by a parameter to use in test environments. +func GenerateMachineID(prefix string) (string, error) { + var err error + if prefix == "" { + prefix, err = generateMachineIDPrefix() + } + + if err != nil { + return "", err + } + + prefix = strings.ReplaceAll(prefix, "-", "")[:32] + suffix := GeneratePassword(16) + + return prefix + suffix, nil +} diff --git a/cmd/crowdsec-cli/idgen/password.go b/cmd/crowdsec-cli/idgen/password.go new file mode 100644 index 00000000000..e0faa4daacc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/password.go @@ -0,0 +1,32 @@ +package idgen + +import ( + saferand "crypto/rand" + "math/big" + + log "github.com/sirupsen/logrus" +) + +const PasswordLength = 64 + +func GeneratePassword(length int) string { + upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" + lower := "abcdefghijklmnopqrstuvwxyz" + digits := "0123456789" + + charset := upper + lower + digits + charsetLength := len(charset) + + buf := make([]byte, length) + + for i := range length { + rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) + if err != nil { + log.Fatalf("failed getting data from prng for password generation : %s", err) + } + + buf[i] = charset[rInt.Int64()] + } + + return string(buf) +} diff --git a/cmd/crowdsec-cli/itemcli.go b/cmd/crowdsec-cli/itemcli.go deleted file mode 100644 index 5b0ad13ffe6..00000000000 --- a/cmd/crowdsec-cli/itemcli.go +++ /dev/null @@ -1,588 +0,0 @@ -package main - -import ( - "fmt" - "os" - "strings" - - "github.com/fatih/color" - "github.com/hexops/gotextdiff" - "github.com/hexops/gotextdiff/myers" - "github.com/hexops/gotextdiff/span" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/go-cs-lib/coalesce" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -type cliHelp struct { - // Example is required, the others have a default value - // generated from the item type - use string - short string - long string - example string -} - -type cliItem struct { - name string // plural, as used in the hub index - singular string - oneOrMore string // parenthetical pluralizaion: "parser(s)" - help cliHelp - installHelp cliHelp - removeHelp cliHelp - upgradeHelp cliHelp - inspectHelp cliHelp - inspectDetail func(item *cwhub.Item) error - listHelp cliHelp -} - -func (cli cliItem) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.help.use, fmt.Sprintf("%s [item]...", cli.name)), - Short: coalesce.String(cli.help.short, fmt.Sprintf("Manage hub %s", cli.name)), - Long: cli.help.long, - Example: cli.help.example, - Args: cobra.MinimumNArgs(1), - Aliases: []string{cli.singular}, - DisableAutoGenTag: true, - } - - cmd.AddCommand(cli.NewInstallCmd()) - cmd.AddCommand(cli.NewRemoveCmd()) - cmd.AddCommand(cli.NewUpgradeCmd()) - cmd.AddCommand(cli.NewInspectCmd()) - cmd.AddCommand(cli.NewListCmd()) - - return cmd -} - -func (cli cliItem) Install(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - downloadOnly, err := flags.GetBool("download-only") - if err != nil { - return err - } - - force, err := flags.GetBool("force") - if err != nil { - return err - } - - ignoreError, err := flags.GetBool("ignore") - if err != nil { - return err - } - - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), log.StandardLogger()) - if err != nil { - return err - } - - for _, name := range args { - item := hub.GetItem(cli.name, name) - if item == nil { - msg := suggestNearestMessage(hub, cli.name, name) - if !ignoreError { - return fmt.Errorf(msg) - } - - log.Errorf(msg) - - continue - } - - if err := item.Install(force, downloadOnly); err != nil { - if !ignoreError { - return fmt.Errorf("error while installing '%s': %w", item.Name, err) - } - - log.Errorf("Error while installing '%s': %s", item.Name, err) - } - } - - log.Infof(ReloadMessage()) - - return nil -} - -func (cli cliItem) NewInstallCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.installHelp.use, "install [item]..."), - Short: coalesce.String(cli.installHelp.short, fmt.Sprintf("Install given %s", cli.oneOrMore)), - Long: coalesce.String(cli.installHelp.long, fmt.Sprintf("Fetch and install one or more %s from the hub", cli.name)), - Example: cli.installHelp.example, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compAllItems(cli.name, args, toComplete) - }, - RunE: cli.Install, - } - - flags := cmd.Flags() - flags.BoolP("download-only", "d", false, "Only download packages, don't enable") - flags.Bool("force", false, "Force install: overwrite tainted and outdated files") - flags.Bool("ignore", false, fmt.Sprintf("Ignore errors when installing multiple %s", cli.name)) - - return cmd -} - -// return the names of the installed parents of an item, used to check if we can remove it -func istalledParentNames(item *cwhub.Item) []string { - ret := make([]string, 0) - - for _, parent := range item.Ancestors() { - if parent.State.Installed { - ret = append(ret, parent.Name) - } - } - - return ret -} - -func (cli cliItem) Remove(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - purge, err := flags.GetBool("purge") - if err != nil { - return err - } - - force, err := flags.GetBool("force") - if err != nil { - return err - } - - all, err := flags.GetBool("all") - if err != nil { - return err - } - - hub, err := require.Hub(csConfig, nil, log.StandardLogger()) - if err != nil { - return err - } - - if all { - getter := hub.GetInstalledItems - if purge { - getter = hub.GetAllItems - } - - items, err := getter(cli.name) - if err != nil { - return err - } - - removed := 0 - - for _, item := range items { - didRemove, err := item.Remove(purge, force) - if err != nil { - return err - } - - if didRemove { - log.Infof("Removed %s", item.Name) - removed++ - } - } - - log.Infof("Removed %d %s", removed, cli.name) - - if removed > 0 { - log.Infof(ReloadMessage()) - } - - return nil - } - - if len(args) == 0 { - return fmt.Errorf("specify at least one %s to remove or '--all'", cli.singular) - } - - removed := 0 - - for _, itemName := range args { - item := hub.GetItem(cli.name, itemName) - if item == nil { - return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) - } - - parents := istalledParentNames(item) - - if !force && len(parents) > 0 { - log.Warningf("%s belongs to collections: %s", item.Name, parents) - log.Warningf("Run 'sudo cscli %s remove %s --force' if you want to force remove this %s", item.Type, item.Name, cli.singular) - - continue - } - - didRemove, err := item.Remove(purge, force) - if err != nil { - return err - } - - if didRemove { - log.Infof("Removed %s", item.Name) - removed++ - } - } - - log.Infof("Removed %d %s", removed, cli.name) - - if removed > 0 { - log.Infof(ReloadMessage()) - } - - return nil -} - -func (cli cliItem) NewRemoveCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.removeHelp.use, "remove [item]..."), - Short: coalesce.String(cli.removeHelp.short, fmt.Sprintf("Remove given %s", cli.oneOrMore)), - Long: coalesce.String(cli.removeHelp.long, fmt.Sprintf("Remove one or more %s", cli.name)), - Example: cli.removeHelp.example, - Aliases: []string{"delete"}, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cli.name, args, toComplete) - }, - RunE: cli.Remove, - } - - flags := cmd.Flags() - flags.Bool("purge", false, "Delete source file too") - flags.Bool("force", false, "Force remove: remove tainted and outdated files") - flags.Bool("all", false, fmt.Sprintf("Remove all the %s", cli.name)) - - return cmd -} - -func (cli cliItem) Upgrade(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - force, err := flags.GetBool("force") - if err != nil { - return err - } - - all, err := flags.GetBool("all") - if err != nil { - return err - } - - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), log.StandardLogger()) - if err != nil { - return err - } - - if all { - items, err := hub.GetInstalledItems(cli.name) - if err != nil { - return err - } - - updated := 0 - - for _, item := range items { - didUpdate, err := item.Upgrade(force) - if err != nil { - return err - } - - if didUpdate { - updated++ - } - } - - log.Infof("Updated %d %s", updated, cli.name) - - if updated > 0 { - log.Infof(ReloadMessage()) - } - - return nil - } - - if len(args) == 0 { - return fmt.Errorf("specify at least one %s to upgrade or '--all'", cli.singular) - } - - updated := 0 - - for _, itemName := range args { - item := hub.GetItem(cli.name, itemName) - if item == nil { - return fmt.Errorf("can't find '%s' in %s", itemName, cli.name) - } - - didUpdate, err := item.Upgrade(force) - if err != nil { - return err - } - - if didUpdate { - log.Infof("Updated %s", item.Name) - updated++ - } - } - - if updated > 0 { - log.Infof(ReloadMessage()) - } - - return nil -} - -func (cli cliItem) NewUpgradeCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.upgradeHelp.use, "upgrade [item]..."), - Short: coalesce.String(cli.upgradeHelp.short, fmt.Sprintf("Upgrade given %s", cli.oneOrMore)), - Long: coalesce.String(cli.upgradeHelp.long, fmt.Sprintf("Fetch and upgrade one or more %s from the hub", cli.name)), - Example: cli.upgradeHelp.example, - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cli.name, args, toComplete) - }, - RunE: cli.Upgrade, - } - - flags := cmd.Flags() - flags.BoolP("all", "a", false, fmt.Sprintf("Upgrade all the %s", cli.name)) - flags.Bool("force", false, "Force upgrade: overwrite tainted and outdated files") - - return cmd -} - -func (cli cliItem) Inspect(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - url, err := flags.GetString("url") - if err != nil { - return err - } - - if url != "" { - csConfig.Cscli.PrometheusUrl = url - } - - diff, err := flags.GetBool("diff") - if err != nil { - return err - } - - rev, err := flags.GetBool("rev") - if err != nil { - return err - } - - noMetrics, err := flags.GetBool("no-metrics") - if err != nil { - return err - } - - remote := (*cwhub.RemoteHubCfg)(nil) - - if diff { - remote = require.RemoteHub(csConfig) - } - - hub, err := require.Hub(csConfig, remote, log.StandardLogger()) - if err != nil { - return err - } - - for _, name := range args { - item := hub.GetItem(cli.name, name) - if item == nil { - return fmt.Errorf("can't find '%s' in %s", name, cli.name) - } - - if diff { - fmt.Println(cli.whyTainted(hub, item, rev)) - - continue - } - - if err = InspectItem(item, !noMetrics); err != nil { - return err - } - - if cli.inspectDetail != nil { - if err = cli.inspectDetail(item); err != nil { - return err - } - } - } - - return nil -} - -func (cli cliItem) NewInspectCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.inspectHelp.use, "inspect [item]..."), - Short: coalesce.String(cli.inspectHelp.short, fmt.Sprintf("Inspect given %s", cli.oneOrMore)), - Long: coalesce.String(cli.inspectHelp.long, fmt.Sprintf("Inspect the state of one or more %s", cli.name)), - Example: cli.inspectHelp.example, - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return compInstalledItems(cli.name, args, toComplete) - }, - PreRunE: func(cmd *cobra.Command, _ []string) error { - flags := cmd.Flags() - - diff, err := flags.GetBool("diff") - if err != nil { - return err - } - - rev, err := flags.GetBool("rev") - if err != nil { - return err - } - - if rev && !diff { - return fmt.Errorf("--rev can only be used with --diff") - } - - return nil - }, - RunE: cli.Inspect, - } - - flags := cmd.Flags() - flags.StringP("url", "u", "", "Prometheus url") - flags.Bool("diff", false, "Show diff with latest version (for tainted items)") - flags.Bool("rev", false, "Reverse diff output") - flags.Bool("no-metrics", false, "Don't show metrics (when cscli.output=human)") - - return cmd -} - -func (cli cliItem) List(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - all, err := flags.GetBool("all") - if err != nil { - return err - } - - hub, err := require.Hub(csConfig, nil, log.StandardLogger()) - if err != nil { - return err - } - - items := make(map[string][]*cwhub.Item) - - items[cli.name], err = selectItems(hub, cli.name, args, !all) - if err != nil { - return err - } - - if err = listItems(color.Output, []string{cli.name}, items, false); err != nil { - return err - } - - return nil -} - -func (cli cliItem) NewListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: coalesce.String(cli.listHelp.use, "list [item... | -a]"), - Short: coalesce.String(cli.listHelp.short, fmt.Sprintf("List %s", cli.oneOrMore)), - Long: coalesce.String(cli.listHelp.long, fmt.Sprintf("List of installed/available/specified %s", cli.name)), - Example: cli.listHelp.example, - DisableAutoGenTag: true, - RunE: cli.List, - } - - flags := cmd.Flags() - flags.BoolP("all", "a", false, "List disabled items as well") - - return cmd -} - -// return the diff between the installed version and the latest version -func (cli cliItem) itemDiff(item *cwhub.Item, reverse bool) (string, error) { - if !item.State.Installed { - return "", fmt.Errorf("'%s' is not installed", item.FQName()) - } - - latestContent, remoteURL, err := item.FetchLatest() - if err != nil { - return "", err - } - - localContent, err := os.ReadFile(item.State.LocalPath) - if err != nil { - return "", fmt.Errorf("while reading %s: %w", item.State.LocalPath, err) - } - - file1 := item.State.LocalPath - file2 := remoteURL - content1 := string(localContent) - content2 := string(latestContent) - - if reverse { - file1, file2 = file2, file1 - content1, content2 = content2, content1 - } - - edits := myers.ComputeEdits(span.URIFromPath(file1), content1, content2) - diff := gotextdiff.ToUnified(file1, file2, content1, edits) - - return fmt.Sprintf("%s", diff), nil -} - -func (cli cliItem) whyTainted(hub *cwhub.Hub, item *cwhub.Item, reverse bool) string { - if !item.State.Installed { - return fmt.Sprintf("# %s is not installed", item.FQName()) - } - - if !item.State.Tainted { - return fmt.Sprintf("# %s is not tainted", item.FQName()) - } - - if len(item.State.TaintedBy) == 0 { - return fmt.Sprintf("# %s is tainted but we don't know why. please report this as a bug", item.FQName()) - } - - ret := []string{ - fmt.Sprintf("# Let's see why %s is tainted.", item.FQName()), - } - - for _, fqsub := range item.State.TaintedBy { - ret = append(ret, fmt.Sprintf("\n-> %s\n", fqsub)) - - sub, err := hub.GetItemFQ(fqsub) - if err != nil { - ret = append(ret, err.Error()) - } - - diff, err := cli.itemDiff(sub, reverse) - if err != nil { - ret = append(ret, err.Error()) - } - - if diff != "" { - ret = append(ret, diff) - } else if len(sub.State.TaintedBy) > 0 { - taintList := strings.Join(sub.State.TaintedBy, ", ") - if sub.FQName() == taintList { - // hack: avoid message "item is tainted by itself" - continue - } - ret = append(ret, fmt.Sprintf("# %s is tainted by %s", sub.FQName(), taintList)) - } - } - - return strings.Join(ret, "\n") -} diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go deleted file mode 100644 index 581683baa8f..00000000000 --- a/cmd/crowdsec-cli/machines.go +++ /dev/null @@ -1,487 +0,0 @@ -package main - -import ( - saferand "crypto/rand" - "encoding/csv" - "encoding/json" - "fmt" - "io" - "math/big" - "os" - "strings" - "time" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - "slices" - - "github.com/crowdsecurity/machineid" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -const passwordLength = 64 - -func generatePassword(length int) string { - upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" - lower := "abcdefghijklmnopqrstuvwxyz" - digits := "0123456789" - - charset := upper + lower + digits - charsetLength := len(charset) - - buf := make([]byte, length) - - for i := 0; i < length; i++ { - rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) - if err != nil { - log.Fatalf("failed getting data from prng for password generation : %s", err) - } - buf[i] = charset[rInt.Int64()] - } - - return string(buf) -} - -// Returns a unique identifier for each crowdsec installation, using an -// identifier of the OS installation where available, otherwise a random -// string. -func generateIDPrefix() (string, error) { - prefix, err := machineid.ID() - if err == nil { - return prefix, nil - } - log.Debugf("failed to get machine-id with usual files: %s", err) - - bID, err := uuid.NewRandom() - if err == nil { - return bID.String(), nil - } - return "", fmt.Errorf("generating machine id: %w", err) -} - -// Generate a unique identifier, composed by a prefix and a random suffix. -// The prefix can be provided by a parameter to use in test environments. -func generateID(prefix string) (string, error) { - var err error - if prefix == "" { - prefix, err = generateIDPrefix() - } - if err != nil { - return "", err - } - prefix = strings.ReplaceAll(prefix, "-", "")[:32] - suffix := generatePassword(16) - return prefix + suffix, nil -} - -// getLastHeartbeat returns the last heartbeat timestamp of a machine -// and a boolean indicating if the machine is considered active or not. -func getLastHeartbeat(m *ent.Machine) (string, bool) { - if m.LastHeartbeat == nil { - return "-", false - } - - elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) - - hb := elapsed.Truncate(time.Second).String() - if elapsed > 2*time.Minute { - return hb, false - } - - return hb, true -} - -func getAgents(out io.Writer, dbClient *database.Client) error { - machines, err := dbClient.ListMachines() - if err != nil { - return fmt.Errorf("unable to list machines: %s", err) - } - - switch csConfig.Cscli.Output { - case "human": - getAgentsTable(out, machines) - case "json": - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - if err := enc.Encode(machines); err != nil { - return fmt.Errorf("failed to marshal") - } - return nil - case "raw": - csvwriter := csv.NewWriter(out) - err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"}) - if err != nil { - return fmt.Errorf("failed to write header: %s", err) - } - for _, m := range machines { - validated := "false" - if m.IsValidated { - validated = "true" - } - hb, _ := getLastHeartbeat(m) - err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb}) - if err != nil { - return fmt.Errorf("failed to write raw output: %w", err) - } - } - csvwriter.Flush() - default: - return fmt.Errorf("unknown output '%s'", csConfig.Cscli.Output) - } - return nil -} - -type cliMachines struct{} - -func NewCLIMachines() *cliMachines { - return &cliMachines{} -} - -func (cli cliMachines) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "machines [action]", - Short: "Manage local API machines [requires local API]", - Long: `To list/add/delete/validate/prune machines. -Note: This command requires database direct access, so is intended to be run on the local API machine. -`, - Example: `cscli machines [action]`, - DisableAutoGenTag: true, - Aliases: []string{"machine"}, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - var err error - if err = require.LAPI(csConfig); err != nil { - return err - } - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } - return nil - }, - } - - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewAddCmd()) - cmd.AddCommand(cli.NewDeleteCmd()) - cmd.AddCommand(cli.NewValidateCmd()) - cmd.AddCommand(cli.NewPruneCmd()) - - return cmd -} - -func (cli cliMachines) NewListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list all machines in the database", - Long: `list all machines in the database with their status and last heartbeat`, - Example: `cscli machines list`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - err := getAgents(color.Output, dbClient) - if err != nil { - return fmt.Errorf("unable to list machines: %s", err) - } - - return nil - }, - } - - return cmd -} - -func (cli cliMachines) NewAddCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "add", - Short: "add a single machine to the database", - DisableAutoGenTag: true, - Long: `Register a new machine in the database. cscli should be on the same machine as LAPI.`, - Example: ` -cscli machines add --auto -cscli machines add MyTestMachine --auto -cscli machines add MyTestMachine --password MyPassword -`, - RunE: cli.add, - } - - flags := cmd.Flags() - flags.StringP("password", "p", "", "machine password to login to the API") - flags.StringP("file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") - flags.StringP("url", "u", "", "URL of the local API") - flags.BoolP("interactive", "i", false, "interfactive mode to enter the password") - flags.BoolP("auto", "a", false, "automatically generate password (and username if not provided)") - flags.Bool("force", false, "will force add the machine if it already exist") - - return cmd -} - -func (cli cliMachines) add(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - machinePassword, err := flags.GetString("password") - if err != nil { - return err - } - - dumpFile, err := flags.GetString("file") - if err != nil { - return err - } - - apiURL, err := flags.GetString("url") - if err != nil { - return err - } - - interactive, err := flags.GetBool("interactive") - if err != nil { - return err - } - - autoAdd, err := flags.GetBool("auto") - if err != nil { - return err - } - - force, err := flags.GetBool("force") - if err != nil { - return err - } - - var machineID string - - // create machineID if not specified by user - if len(args) == 0 { - if !autoAdd { - printHelp(cmd) - return nil - } - machineID, err = generateID("") - if err != nil { - return fmt.Errorf("unable to generate machine id: %s", err) - } - } else { - machineID = args[0] - } - - /*check if file already exists*/ - if dumpFile == "" && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { - credFile := csConfig.API.Client.CredentialsFilePath - // use the default only if the file does not exist - _, err = os.Stat(credFile) - - switch { - case os.IsNotExist(err) || force: - dumpFile = csConfig.API.Client.CredentialsFilePath - case err != nil: - return fmt.Errorf("unable to stat '%s': %s", credFile, err) - default: - return fmt.Errorf(`credentials file '%s' already exists: please remove it, use "--force" or specify a different file with "-f" ("-f -" for standard output)`, credFile) - } - } - - if dumpFile == "" { - return fmt.Errorf(`please specify a file to dump credentials to, with -f ("-f -" for standard output)`) - } - - // create a password if it's not specified by user - if machinePassword == "" && !interactive { - if !autoAdd { - return fmt.Errorf("please specify a password with --password or use --auto") - } - machinePassword = generatePassword(passwordLength) - } else if machinePassword == "" && interactive { - qs := &survey.Password{ - Message: "Please provide a password for the machine", - } - survey.AskOne(qs, &machinePassword) - } - password := strfmt.Password(machinePassword) - _, err = dbClient.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType) - if err != nil { - return fmt.Errorf("unable to create machine: %s", err) - } - fmt.Printf("Machine '%s' successfully added to the local API.\n", machineID) - - if apiURL == "" { - if csConfig.API.Client != nil && csConfig.API.Client.Credentials != nil && csConfig.API.Client.Credentials.URL != "" { - apiURL = csConfig.API.Client.Credentials.URL - } else if csConfig.API.Server != nil && csConfig.API.Server.ListenURI != "" { - apiURL = "http://" + csConfig.API.Server.ListenURI - } else { - return fmt.Errorf("unable to dump an api URL. Please provide it in your configuration or with the -u parameter") - } - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: machineID, - Password: password.String(), - URL: apiURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %s", err) - } - if dumpFile != "" && dumpFile != "-" { - err = os.WriteFile(dumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", dumpFile, err) - } - fmt.Printf("API credentials written to '%s'.\n", dumpFile) - } else { - fmt.Printf("%s\n", string(apiConfigDump)) - } - - return nil -} - -func (cli cliMachines) NewDeleteCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "delete [machine_name]...", - Short: "delete machine(s) by name", - Example: `cscli machines delete "machine1" "machine2"`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - ValidArgsFunction: func(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - machines, err := dbClient.ListMachines() - if err != nil { - cobra.CompError("unable to list machines " + err.Error()) - } - ret := make([]string, 0) - for _, machine := range machines { - if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { - ret = append(ret, machine.MachineId) - } - } - return ret, cobra.ShellCompDirectiveNoFileComp - }, - RunE: cli.delete, - } - - return cmd -} - -func (cli cliMachines) delete(_ *cobra.Command, args []string) error { - for _, machineID := range args { - err := dbClient.DeleteWatcher(machineID) - if err != nil { - log.Errorf("unable to delete machine '%s': %s", machineID, err) - return nil - } - log.Infof("machine '%s' deleted successfully", machineID) - } - - return nil -} - -func (cli cliMachines) NewPruneCmd() *cobra.Command { - var parsedDuration time.Duration - cmd := &cobra.Command{ - Use: "prune", - Short: "prune multiple machines from the database", - Long: `prune multiple machines that are not validated or have not connected to the local API in a given duration.`, - Example: `cscli machines prune -cscli machines prune --duration 1h -cscli machines prune --not-validated-only --force`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - PreRunE: func(cmd *cobra.Command, _ []string) error { - dur, _ := cmd.Flags().GetString("duration") - var err error - parsedDuration, err = time.ParseDuration(fmt.Sprintf("-%s", dur)) - if err != nil { - return fmt.Errorf("unable to parse duration '%s': %s", dur, err) - } - return nil - }, - RunE: func(cmd *cobra.Command, _ []string) error { - notValidOnly, _ := cmd.Flags().GetBool("not-validated-only") - force, _ := cmd.Flags().GetBool("force") - if parsedDuration >= 0-60*time.Second && !notValidOnly { - var answer bool - prompt := &survey.Confirm{ - Message: "The duration you provided is less than or equal 60 seconds this can break installations do you want to continue ?", - Default: false, - } - if err := survey.AskOne(prompt, &answer); err != nil { - return fmt.Errorf("unable to ask about prune check: %s", err) - } - if !answer { - fmt.Println("user aborted prune no changes were made") - return nil - } - } - machines := make([]*ent.Machine, 0) - if pending, err := dbClient.QueryPendingMachine(); err == nil { - machines = append(machines, pending...) - } - if !notValidOnly { - if pending, err := dbClient.QueryLastValidatedHeartbeatLT(time.Now().UTC().Add(parsedDuration)); err == nil { - machines = append(machines, pending...) - } - } - if len(machines) == 0 { - fmt.Println("no machines to prune") - return nil - } - getAgentsTable(color.Output, machines) - if !force { - var answer bool - prompt := &survey.Confirm{ - Message: "You are about to PERMANENTLY remove the above machines from the database these will NOT be recoverable, continue ?", - Default: false, - } - if err := survey.AskOne(prompt, &answer); err != nil { - return fmt.Errorf("unable to ask about prune check: %s", err) - } - if !answer { - fmt.Println("user aborted prune no changes were made") - return nil - } - } - nbDeleted, err := dbClient.BulkDeleteWatchers(machines) - if err != nil { - return fmt.Errorf("unable to prune machines: %s", err) - } - fmt.Printf("successfully delete %d machines\n", nbDeleted) - return nil - }, - } - cmd.Flags().StringP("duration", "d", "10m", "duration of time since validated machine last heartbeat") - cmd.Flags().Bool("not-validated-only", false, "only prune machines that are not validated") - cmd.Flags().Bool("force", false, "force prune without asking for confirmation") - - return cmd -} - -func (cli cliMachines) NewValidateCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "validate", - Short: "validate a machine to access the local API", - Long: `validate a machine to access the local API.`, - Example: `cscli machines validate "machine_name"`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - machineID := args[0] - if err := dbClient.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine '%s': %s", machineID, err) - } - log.Infof("machine '%s' validated successfully", machineID) - - return nil - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/machines_table.go b/cmd/crowdsec-cli/machines_table.go deleted file mode 100644 index e166fb785a6..00000000000 --- a/cmd/crowdsec-cli/machines_table.go +++ /dev/null @@ -1,35 +0,0 @@ -package main - -import ( - "io" - "time" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/database/ent" -) - -func getAgentsTable(out io.Writer, machines []*ent.Machine) { - t := newLightTable(out) - t.SetHeaders("Name", "IP Address", "Last Update", "Status", "Version", "Auth Type", "Last Heartbeat") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, m := range machines { - var validated string - if m.IsValidated { - validated = emoji.CheckMark.String() - } else { - validated = emoji.Prohibited.String() - } - - hb, active := getLastHeartbeat(m) - if !active { - hb = emoji.Warning.String() + " " + hb - } - t.AddRow(m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb) - } - - t.Render() -} diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index fda4cddc2bc..1cca03b1d3d 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -1,7 +1,9 @@ package main import ( + "fmt" "os" + "path/filepath" "slices" "time" @@ -10,48 +12,116 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliconsole" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clidecision" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliexplain" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihubtest" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliitem" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clinotifications" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisimulation" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisupport" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -var trace_lvl, dbg_lvl, nfo_lvl, wrn_lvl, err_lvl bool - -var ConfigFilePath string -var csConfig *csconfig.Config -var dbClient *database.Client +var ( + ConfigFilePath string + csConfig *csconfig.Config +) -var OutputFormat string -var OutputColor string +type configGetter func() *csconfig.Config var mergedConfig string -// flagBranch overrides the value in csConfig.Cscli.HubBranch -var flagBranch = "" +type cliRoot struct { + logTrace bool + logDebug bool + logInfo bool + logWarn bool + logErr bool + outputColor string + outputFormat string + // flagBranch overrides the value in csConfig.Cscli.HubBranch + flagBranch string +} -func initConfig() { - var err error +func newCliRoot() *cliRoot { + return &cliRoot{} +} - if trace_lvl { - log.SetLevel(log.TraceLevel) - } else if dbg_lvl { - log.SetLevel(log.DebugLevel) - } else if nfo_lvl { - log.SetLevel(log.InfoLevel) - } else if wrn_lvl { - log.SetLevel(log.WarnLevel) - } else if err_lvl { - log.SetLevel(log.ErrorLevel) +// cfg() is a helper function to get the configuration loaded from config.yaml, +// we pass it to subcommands because the file is not read until the Execute() call +func (cli *cliRoot) cfg() *csconfig.Config { + return csConfig +} + +// wantedLogLevel returns the log level requested in the command line flags. +func (cli *cliRoot) wantedLogLevel() log.Level { + switch { + case cli.logTrace: + return log.TraceLevel + case cli.logDebug: + return log.DebugLevel + case cli.logInfo: + return log.InfoLevel + case cli.logWarn: + return log.WarnLevel + case cli.logErr: + return log.ErrorLevel + default: + return log.InfoLevel + } +} + +// loadConfigFor loads the configuration file for the given sub-command. +// If the sub-command does not need it, it returns a default configuration. +func loadConfigFor(command string) (*csconfig.Config, string, error) { + noNeedConfig := []string{ + "doc", + "help", + "completion", + "version", + "hubtest", } - if !slices.Contains(NoNeedConfig, os.Args[1]) { + if !slices.Contains(noNeedConfig, command) { log.Debugf("Using %s as configuration file", ConfigFilePath) - csConfig, mergedConfig, err = csconfig.NewConfig(ConfigFilePath, false, false, true) + + config, merged, err := csconfig.NewConfig(ConfigFilePath, false, false, true) if err != nil { - log.Fatal(err) + return nil, "", err } - } else { - csConfig = csconfig.NewDefaultConfig() + + // set up directory for trace files + if err := trace.Init(filepath.Join(config.ConfigPaths.DataDir, "trace")); err != nil { + return nil, "", fmt.Errorf("while setting up trace directory: %w", err) + } + + return config, merged, nil + } + + return csconfig.NewDefaultConfig(), "", nil +} + +// initialize is called before the subcommand is executed. +func (cli *cliRoot) initialize() error { + var err error + + log.SetLevel(cli.wantedLogLevel()) + + csConfig, mergedConfig, err = loadConfigFor(os.Args[1]) + if err != nil { + return err } // recap of the enabled feature flags, because logging @@ -60,20 +130,24 @@ func initConfig() { log.Debugf("Enabled feature flags: %s", fflist) } - if flagBranch != "" { - csConfig.Cscli.HubBranch = flagBranch + if cli.flagBranch != "" { + csConfig.Cscli.HubBranch = cli.flagBranch } - if OutputFormat != "" { - csConfig.Cscli.Output = OutputFormat - - if OutputFormat != "json" && OutputFormat != "raw" && OutputFormat != "human" { - log.Fatalf("output format %s unknown", OutputFormat) - } + if cli.outputFormat != "" { + csConfig.Cscli.Output = cli.outputFormat } + if csConfig.Cscli.Output == "" { csConfig.Cscli.Output = "human" } + + if csConfig.Cscli.Output != "human" && csConfig.Cscli.Output != "json" && csConfig.Cscli.Output != "raw" { + return fmt.Errorf("output format '%s' not supported: must be one of human, json, raw", csConfig.Cscli.Output) + } + + log.SetFormatter(&log.TextFormatter{DisableTimestamp: true}) + if csConfig.Cscli.Output == "json" { log.SetFormatter(&log.JSONFormatter{}) log.SetLevel(log.ErrorLevel) @@ -81,42 +155,54 @@ func initConfig() { log.SetLevel(log.ErrorLevel) } - if OutputColor != "" { - csConfig.Cscli.Color = OutputColor + if cli.outputColor != "" { + csConfig.Cscli.Color = cli.outputColor - if OutputColor != "yes" && OutputColor != "no" && OutputColor != "auto" { - log.Fatalf("output color %s unknown", OutputColor) + if cli.outputColor != "yes" && cli.outputColor != "no" && cli.outputColor != "auto" { + return fmt.Errorf("output color '%s' not supported: must be one of yes, no, auto", cli.outputColor) } } -} -// list of valid subcommands for the shell completion -var validArgs = []string{ - "alerts", "appsec-configs", "appsec-rules", "bouncers", "capi", "collections", - "completion", "config", "console", "contexts", "dashboard", "decisions", "explain", - "hub", "hubtest", "lapi", "machines", "metrics", "notifications", "parsers", - "postoverflows", "scenarios", "simulation", "support", "version", + return nil } -var NoNeedConfig = []string{ - "doc", - "help", - "completion", - "version", - "hubtest", +func (cli *cliRoot) colorize(cmd *cobra.Command) { + cc.Init(&cc.Config{ + RootCmd: cmd, + Headings: cc.Yellow, + Commands: cc.Green + cc.Bold, + CmdShortDescr: cc.Cyan, + Example: cc.Italic, + ExecName: cc.Bold, + Aliases: cc.Bold + cc.Italic, + FlagsDataType: cc.White, + Flags: cc.Green, + FlagsDescr: cc.Cyan, + NoExtraNewlines: true, + NoBottomNewline: true, + }) + cmd.SetOut(color.Output) } -func main() { +func (cli *cliRoot) NewCommand() (*cobra.Command, error) { // set the formatter asap and worry about level later logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true} log.SetFormatter(logFormatter) if err := fflag.RegisterAllFeatures(); err != nil { - log.Fatalf("failed to register features: %s", err) + return nil, fmt.Errorf("failed to register features: %w", err) } if err := csconfig.LoadFeatureFlagsEnv(log.StandardLogger()); err != nil { - log.Fatalf("failed to set feature flags from env: %s", err) + return nil, fmt.Errorf("failed to set feature flags from env: %w", err) + } + + // list of valid subcommands for the shell completion + validArgs := []string{ + "alerts", "appsec-configs", "appsec-rules", "bouncers", "capi", "collections", + "completion", "config", "console", "contexts", "dashboard", "decisions", "explain", + "hub", "hubtest", "lapi", "machines", "metrics", "notifications", "parsers", + "postoverflows", "scenarios", "simulation", "support", "version", } cmd := &cobra.Command{ @@ -131,33 +217,25 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall /*TBD examples*/ } - cc.Init(&cc.Config{ - RootCmd: cmd, - Headings: cc.Yellow, - Commands: cc.Green + cc.Bold, - CmdShortDescr: cc.Cyan, - Example: cc.Italic, - ExecName: cc.Bold, - Aliases: cc.Bold + cc.Italic, - FlagsDataType: cc.White, - Flags: cc.Green, - FlagsDescr: cc.Cyan, - }) - cmd.SetOut(color.Output) + cli.colorize(cmd) - cmd.PersistentFlags().StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") - cmd.PersistentFlags().StringVarP(&OutputFormat, "output", "o", "", "Output format: human, json, raw") - cmd.PersistentFlags().StringVarP(&OutputColor, "color", "", "auto", "Output color: yes, no, auto") - cmd.PersistentFlags().BoolVar(&dbg_lvl, "debug", false, "Set logging to debug") - cmd.PersistentFlags().BoolVar(&nfo_lvl, "info", false, "Set logging to info") - cmd.PersistentFlags().BoolVar(&wrn_lvl, "warning", false, "Set logging to warning") - cmd.PersistentFlags().BoolVar(&err_lvl, "error", false, "Set logging to error") - cmd.PersistentFlags().BoolVar(&trace_lvl, "trace", false, "Set logging to trace") + /*don't sort flags so we can enforce order*/ + cmd.Flags().SortFlags = false - cmd.PersistentFlags().StringVar(&flagBranch, "branch", "", "Override hub branch on github") - if err := cmd.PersistentFlags().MarkHidden("branch"); err != nil { - log.Fatalf("failed to hide flag: %s", err) - } + pflags := cmd.PersistentFlags() + pflags.SortFlags = false + + pflags.StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file") + pflags.StringVarP(&cli.outputFormat, "output", "o", "", "Output format: human, json, raw") + pflags.StringVarP(&cli.outputColor, "color", "", "auto", "Output color: yes, no, auto") + pflags.BoolVar(&cli.logDebug, "debug", false, "Set logging to debug") + pflags.BoolVar(&cli.logInfo, "info", false, "Set logging to info") + pflags.BoolVar(&cli.logWarn, "warning", false, "Set logging to warning") + pflags.BoolVar(&cli.logErr, "error", false, "Set logging to error") + pflags.BoolVar(&cli.logTrace, "trace", false, "Set logging to trace") + pflags.StringVar(&cli.flagBranch, "branch", "", "Override hub branch on github") + + _ = pflags.MarkHidden("branch") // Look for "-c /path/to/config.yaml" // This duplicates the logic in cobra, but we need to do it before @@ -171,56 +249,56 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall } if err := csconfig.LoadFeatureFlagsFile(ConfigFilePath, log.StandardLogger()); err != nil { - log.Fatal(err) + return nil, err } + cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) + cmd.AddCommand(NewCLIVersion().NewCommand()) + cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) + cmd.AddCommand(clihub.New(cli.cfg).NewCommand()) + cmd.AddCommand(climetrics.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) + cmd.AddCommand(clidecision.New(cli.cfg).NewCommand()) + cmd.AddCommand(clialert.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisimulation.New(cli.cfg).NewCommand()) + cmd.AddCommand(clibouncer.New(cli.cfg).NewCommand()) + cmd.AddCommand(climachine.New(cli.cfg).NewCommand()) + cmd.AddCommand(clicapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(clilapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCompletionCmd()) + cmd.AddCommand(cliconsole.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliexplain.New(cli.cfg, ConfigFilePath).NewCommand()) + cmd.AddCommand(clihubtest.New(cli.cfg).NewCommand()) + cmd.AddCommand(clinotifications.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisupport.New(cli.cfg).NewCommand()) + cmd.AddCommand(clipapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewCollection(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewParser(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewScenario(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewPostOverflow(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewContext(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecConfig(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecRule(cli.cfg).NewCommand()) + + cli.addSetup(cmd) + if len(os.Args) > 1 { - cobra.OnInitialize(initConfig) + cobra.OnInitialize( + func() { + if err := cli.initialize(); err != nil { + log.Fatal(err) + } + }, + ) } - /*don't sort flags so we can enforce order*/ - cmd.Flags().SortFlags = false - cmd.PersistentFlags().SortFlags = false - - // we use a getter because the config is not initialized until the Execute() call - getconfig := func() *csconfig.Config { - return csConfig - } + return cmd, nil +} - cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) - cmd.AddCommand(NewCLIVersion().NewCommand()) - cmd.AddCommand(NewConfigCmd()) - cmd.AddCommand(NewCLIHub().NewCommand()) - cmd.AddCommand(NewMetricsCmd()) - cmd.AddCommand(NewCLIDashboard().NewCommand()) - cmd.AddCommand(NewCLIDecisions().NewCommand()) - cmd.AddCommand(NewCLIAlerts().NewCommand()) - cmd.AddCommand(NewCLISimulation().NewCommand()) - cmd.AddCommand(NewCLIBouncers(getconfig).NewCommand()) - cmd.AddCommand(NewCLIMachines().NewCommand()) - cmd.AddCommand(NewCLICapi().NewCommand()) - cmd.AddCommand(NewLapiCmd()) - cmd.AddCommand(NewCompletionCmd()) - cmd.AddCommand(NewConsoleCmd()) - cmd.AddCommand(NewCLIExplain().NewCommand()) - cmd.AddCommand(NewCLIHubTest().NewCommand()) - cmd.AddCommand(NewCLINotifications().NewCommand()) - cmd.AddCommand(NewCLISupport().NewCommand()) - cmd.AddCommand(NewCLIPapi().NewCommand()) - cmd.AddCommand(NewCLICollection().NewCommand()) - cmd.AddCommand(NewCLIParser().NewCommand()) - cmd.AddCommand(NewCLIScenario().NewCommand()) - cmd.AddCommand(NewCLIPostOverflow().NewCommand()) - cmd.AddCommand(NewCLIContext().NewCommand()) - cmd.AddCommand(NewCLIAppsecConfig().NewCommand()) - cmd.AddCommand(NewCLIAppsecRule().NewCommand()) - - if fflag.CscliSetup.IsEnabled() { - cmd.AddCommand(NewSetupCmd()) - } - - if fflag.PapiClient.IsEnabled() { - cmd.AddCommand(NewCLIPapi().NewCommand()) +func main() { + cmd, err := newCliRoot().NewCommand() + if err != nil { + log.Fatal(err) } if err := cmd.Execute(); err != nil { diff --git a/cmd/crowdsec-cli/messages.go b/cmd/crowdsec-cli/messages.go deleted file mode 100644 index 02f051601e4..00000000000 --- a/cmd/crowdsec-cli/messages.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import ( - "fmt" - "runtime" -) - -// ReloadMessage returns a description of the task required to reload -// the crowdsec configuration, according to the operating system. -func ReloadMessage() string { - var msg string - - switch runtime.GOOS { - case "windows": - msg = "Please restart the crowdsec service" - case "freebsd": - msg = `Run 'sudo service crowdsec reload'` - default: - msg = `Run 'sudo systemctl reload crowdsec'` - } - - return fmt.Sprintf("%s for the new configuration to be effective.", msg) -} diff --git a/cmd/crowdsec-cli/metrics.go b/cmd/crowdsec-cli/metrics.go deleted file mode 100644 index 5b24dc84c91..00000000000 --- a/cmd/crowdsec-cli/metrics.go +++ /dev/null @@ -1,355 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "time" - - "github.com/fatih/color" - dto "github.com/prometheus/client_model/go" - "github.com/prometheus/prom2json" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/go-cs-lib/trace" -) - -// FormatPrometheusMetrics is a complete rip from prom2json -func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error { - mfChan := make(chan *dto.MetricFamily, 1024) - errChan := make(chan error, 1) - - // Start with the DefaultTransport for sane defaults. - transport := http.DefaultTransport.(*http.Transport).Clone() - // Conservatively disable HTTP keep-alives as this program will only - // ever need a single HTTP request. - transport.DisableKeepAlives = true - // Timeout early if the server doesn't even return the headers. - transport.ResponseHeaderTimeout = time.Minute - go func() { - defer trace.CatchPanic("crowdsec/ShowPrometheus") - err := prom2json.FetchMetricFamilies(url, mfChan, transport) - if err != nil { - errChan <- fmt.Errorf("failed to fetch prometheus metrics: %w", err) - return - } - errChan <- nil - }() - - result := []*prom2json.Family{} - for mf := range mfChan { - result = append(result, prom2json.NewFamily(mf)) - } - - if err := <-errChan; err != nil { - return err - } - - log.Debugf("Finished reading prometheus output, %d entries", len(result)) - /*walk*/ - lapi_decisions_stats := map[string]struct { - NonEmpty int - Empty int - }{} - acquis_stats := map[string]map[string]int{} - parsers_stats := map[string]map[string]int{} - buckets_stats := map[string]map[string]int{} - lapi_stats := map[string]map[string]int{} - lapi_machine_stats := map[string]map[string]map[string]int{} - lapi_bouncer_stats := map[string]map[string]map[string]int{} - decisions_stats := map[string]map[string]map[string]int{} - appsec_engine_stats := map[string]map[string]int{} - appsec_rule_stats := map[string]map[string]map[string]int{} - alerts_stats := map[string]int{} - stash_stats := map[string]struct { - Type string - Count int - }{} - - for idx, fam := range result { - if !strings.HasPrefix(fam.Name, "cs_") { - continue - } - log.Tracef("round %d", idx) - for _, m := range fam.Metrics { - metric, ok := m.(prom2json.Metric) - if !ok { - log.Debugf("failed to convert metric to prom2json.Metric") - continue - } - name, ok := metric.Labels["name"] - if !ok { - log.Debugf("no name in Metric %v", metric.Labels) - } - source, ok := metric.Labels["source"] - if !ok { - log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name) - } else { - if srctype, ok := metric.Labels["type"]; ok { - source = srctype + ":" + source - } - } - - value := m.(prom2json.Metric).Value - machine := metric.Labels["machine"] - bouncer := metric.Labels["bouncer"] - - route := metric.Labels["route"] - method := metric.Labels["method"] - - reason := metric.Labels["reason"] - origin := metric.Labels["origin"] - action := metric.Labels["action"] - - mtype := metric.Labels["type"] - - fval, err := strconv.ParseFloat(value, 32) - if err != nil { - log.Errorf("Unexpected int value %s : %s", value, err) - } - ival := int(fval) - switch fam.Name { - /*buckets*/ - case "cs_bucket_created_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["instantiation"] += ival - case "cs_buckets": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["curr_count"] += ival - case "cs_bucket_overflowed_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["overflow"] += ival - case "cs_bucket_poured_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - buckets_stats[name]["pour"] += ival - acquis_stats[source]["pour"] += ival - case "cs_bucket_underflowed_total": - if _, ok := buckets_stats[name]; !ok { - buckets_stats[name] = make(map[string]int) - } - buckets_stats[name]["underflow"] += ival - /*acquis*/ - case "cs_parser_hits_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["reads"] += ival - case "cs_parser_hits_ok_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["parsed"] += ival - case "cs_parser_hits_ko_total": - if _, ok := acquis_stats[source]; !ok { - acquis_stats[source] = make(map[string]int) - } - acquis_stats[source]["unparsed"] += ival - case "cs_node_hits_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["hits"] += ival - case "cs_node_hits_ok_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["parsed"] += ival - case "cs_node_hits_ko_total": - if _, ok := parsers_stats[name]; !ok { - parsers_stats[name] = make(map[string]int) - } - parsers_stats[name]["unparsed"] += ival - case "cs_lapi_route_requests_total": - if _, ok := lapi_stats[route]; !ok { - lapi_stats[route] = make(map[string]int) - } - lapi_stats[route][method] += ival - case "cs_lapi_machine_requests_total": - if _, ok := lapi_machine_stats[machine]; !ok { - lapi_machine_stats[machine] = make(map[string]map[string]int) - } - if _, ok := lapi_machine_stats[machine][route]; !ok { - lapi_machine_stats[machine][route] = make(map[string]int) - } - lapi_machine_stats[machine][route][method] += ival - case "cs_lapi_bouncer_requests_total": - if _, ok := lapi_bouncer_stats[bouncer]; !ok { - lapi_bouncer_stats[bouncer] = make(map[string]map[string]int) - } - if _, ok := lapi_bouncer_stats[bouncer][route]; !ok { - lapi_bouncer_stats[bouncer][route] = make(map[string]int) - } - lapi_bouncer_stats[bouncer][route][method] += ival - case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total": - if _, ok := lapi_decisions_stats[bouncer]; !ok { - lapi_decisions_stats[bouncer] = struct { - NonEmpty int - Empty int - }{} - } - x := lapi_decisions_stats[bouncer] - if fam.Name == "cs_lapi_decisions_ko_total" { - x.Empty += ival - } else if fam.Name == "cs_lapi_decisions_ok_total" { - x.NonEmpty += ival - } - lapi_decisions_stats[bouncer] = x - case "cs_active_decisions": - if _, ok := decisions_stats[reason]; !ok { - decisions_stats[reason] = make(map[string]map[string]int) - } - if _, ok := decisions_stats[reason][origin]; !ok { - decisions_stats[reason][origin] = make(map[string]int) - } - decisions_stats[reason][origin][action] += ival - case "cs_alerts": - /*if _, ok := alerts_stats[scenario]; !ok { - alerts_stats[scenario] = make(map[string]int) - }*/ - alerts_stats[reason] += ival - case "cs_cache_size": - stash_stats[name] = struct { - Type string - Count int - }{Type: mtype, Count: ival} - case "cs_appsec_reqs_total": - if _, ok := appsec_engine_stats[metric.Labels["appsec_engine"]]; !ok { - appsec_engine_stats[metric.Labels["appsec_engine"]] = make(map[string]int, 0) - } - appsec_engine_stats[metric.Labels["appsec_engine"]]["processed"] = ival - case "cs_appsec_block_total": - if _, ok := appsec_engine_stats[metric.Labels["appsec_engine"]]; !ok { - appsec_engine_stats[metric.Labels["appsec_engine"]] = make(map[string]int, 0) - } - appsec_engine_stats[metric.Labels["appsec_engine"]]["blocked"] = ival - case "cs_appsec_rule_hits": - appsecEngine := metric.Labels["appsec_engine"] - ruleID := metric.Labels["rule_name"] - if _, ok := appsec_rule_stats[appsecEngine]; !ok { - appsec_rule_stats[appsecEngine] = make(map[string]map[string]int, 0) - } - if _, ok := appsec_rule_stats[appsecEngine][ruleID]; !ok { - appsec_rule_stats[appsecEngine][ruleID] = make(map[string]int, 0) - } - appsec_rule_stats[appsecEngine][ruleID]["triggered"] = ival - default: - log.Debugf("unknown: %+v", fam.Name) - continue - } - } - } - - if formatType == "human" { - acquisStatsTable(out, acquis_stats) - bucketStatsTable(out, buckets_stats) - parserStatsTable(out, parsers_stats) - lapiStatsTable(out, lapi_stats) - lapiMachineStatsTable(out, lapi_machine_stats) - lapiBouncerStatsTable(out, lapi_bouncer_stats) - lapiDecisionStatsTable(out, lapi_decisions_stats) - decisionStatsTable(out, decisions_stats) - alertStatsTable(out, alerts_stats) - stashStatsTable(out, stash_stats) - appsecMetricsToTable(out, appsec_engine_stats) - appsecRulesToTable(out, appsec_rule_stats) - return nil - } - - stats := make(map[string]any) - - stats["acquisition"] = acquis_stats - stats["buckets"] = buckets_stats - stats["parsers"] = parsers_stats - stats["lapi"] = lapi_stats - stats["lapi_machine"] = lapi_machine_stats - stats["lapi_bouncer"] = lapi_bouncer_stats - stats["lapi_decisions"] = lapi_decisions_stats - stats["decisions"] = decisions_stats - stats["alerts"] = alerts_stats - stats["stash"] = stash_stats - - switch formatType { - case "json": - x, err := json.MarshalIndent(stats, "", " ") - if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) - } - out.Write(x) - case "raw": - x, err := yaml.Marshal(stats) - if err != nil { - return fmt.Errorf("failed to unmarshal metrics : %v", err) - } - out.Write(x) - default: - return fmt.Errorf("unknown format type %s", formatType) - } - - return nil -} - -var noUnit bool - -func runMetrics(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - url, err := flags.GetString("url") - if err != nil { - return err - } - - if url != "" { - csConfig.Cscli.PrometheusUrl = url - } - - noUnit, err = flags.GetBool("no-unit") - if err != nil { - return err - } - - if csConfig.Prometheus == nil { - return fmt.Errorf("prometheus section missing, can't show metrics") - } - - if !csConfig.Prometheus.Enabled { - return fmt.Errorf("prometheus is not enabled, can't show metrics") - } - - if err = FormatPrometheusMetrics(color.Output, csConfig.Cscli.PrometheusUrl, csConfig.Cscli.Output); err != nil { - return err - } - return nil -} - -func NewMetricsCmd() *cobra.Command { - cmdMetrics := &cobra.Command{ - Use: "metrics", - Short: "Display crowdsec prometheus metrics.", - Long: `Fetch metrics from the prometheus server and display them in a human-friendly way`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: runMetrics, - } - - flags := cmdMetrics.PersistentFlags() - flags.StringP("url", "u", "", "Prometheus url (http://:/metrics)") - flags.Bool("no-unit", false, "Show the real number instead of formatted with units") - - return cmdMetrics -} diff --git a/cmd/crowdsec-cli/metrics_table.go b/cmd/crowdsec-cli/metrics_table.go deleted file mode 100644 index 80b9cb6e435..00000000000 --- a/cmd/crowdsec-cli/metrics_table.go +++ /dev/null @@ -1,338 +0,0 @@ -package main - -import ( - "fmt" - "io" - "sort" - - "github.com/aquasecurity/table" - log "github.com/sirupsen/logrus" -) - -func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int) int { - // stats: machine -> route -> method -> count - - // sort keys to keep consistent order when printing - machineKeys := []string{} - for k := range stats { - machineKeys = append(machineKeys, k) - } - sort.Strings(machineKeys) - - numRows := 0 - for _, machine := range machineKeys { - // oneRow: route -> method -> count - machineRow := stats[machine] - for routeName, route := range machineRow { - for methodName, count := range route { - row := []string{ - machine, - routeName, - methodName, - } - if count != 0 { - row = append(row, fmt.Sprintf("%d", count)) - } else { - row = append(row, "-") - } - t.AddRow(row...) - numRows++ - } - } - } - return numRows -} - -func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string) (int, error) { - if t == nil { - return 0, fmt.Errorf("nil table") - } - // sort keys to keep consistent order when printing - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats, ok := stats[alabel] - if !ok { - continue - } - row := []string{ - alabel, - } - for _, sl := range keys { - if v, ok := astats[sl]; ok && v != 0 { - numberToShow := fmt.Sprintf("%d", v) - if !noUnit { - numberToShow = formatNumber(v) - } - - row = append(row, numberToShow) - } else { - row = append(row, "-") - } - } - t.AddRow(row...) - numRows++ - } - return numRows, nil -} - -func bucketStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bucket", "Current Count", "Overflows", "Instantiated", "Poured", "Expired") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting bucket stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nBucket Metrics:") - t.Render() - } -} - -func acquisStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"reads", "parsed", "unparsed", "pour"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting acquis stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nAcquisition Metrics:") - t.Render() - } -} - -func appsecMetricsToTable(out io.Writer, metrics map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Appsec Engine", "Processed", "Blocked") - t.SetAlignment(table.AlignLeft, table.AlignLeft) - keys := []string{"processed", "blocked"} - if numRows, err := metricsToTable(t, metrics, keys); err != nil { - log.Warningf("while collecting appsec stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nAppsec Metrics:") - t.Render() - } -} - -func appsecRulesToTable(out io.Writer, metrics map[string]map[string]map[string]int) { - for appsecEngine, appsecEngineRulesStats := range metrics { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Rule ID", "Triggered") - t.SetAlignment(table.AlignLeft, table.AlignLeft) - keys := []string{"triggered"} - if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys); err != nil { - log.Warningf("while collecting appsec rules stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, fmt.Sprintf("\nAppsec '%s' Rules Metrics:", appsecEngine)) - t.Render() - } - } - -} - -func parserStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - keys := []string{"hits", "parsed", "unparsed"} - - if numRows, err := metricsToTable(t, stats, keys); err != nil { - log.Warningf("while collecting parsers stats: %s", err) - } else if numRows > 0 { - renderTableTitle(out, "\nParser Metrics:") - t.Render() - } -} - -func stashStatsTable(out io.Writer, stats map[string]struct { - Type string - Count int -}) { - - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Name", "Type", "Items") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats := stats[alabel] - - row := []string{ - alabel, - astats.Type, - fmt.Sprintf("%d", astats.Count), - } - t.AddRow(row...) - numRows++ - } - if numRows > 0 { - renderTableTitle(out, "\nParser Stash Metrics:") - t.Render() - } -} - -func lapiStatsTable(out io.Writer, stats map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - // unfortunately, we can't reuse metricsToTable as the structure is too different :/ - sortedKeys := []string{} - for k := range stats { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - numRows := 0 - for _, alabel := range sortedKeys { - astats := stats[alabel] - - subKeys := []string{} - for skey := range astats { - subKeys = append(subKeys, skey) - } - sort.Strings(subKeys) - - for _, sl := range subKeys { - row := []string{ - alabel, - sl, - fmt.Sprintf("%d", astats[sl]), - } - t.AddRow(row...) - numRows++ - } - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Metrics:") - t.Render() - } -} - -func lapiMachineStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Machine", "Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := lapiMetricsToTable(t, stats) - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Machines Metrics:") - t.Render() - } -} - -func lapiBouncerStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bouncer", "Route", "Method", "Hits") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := lapiMetricsToTable(t, stats) - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Metrics:") - t.Render() - } -} - -func lapiDecisionStatsTable(out io.Writer, stats map[string]struct { - NonEmpty int - Empty int -}, -) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Bouncer", "Empty answers", "Non-empty answers") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := 0 - for bouncer, hits := range stats { - t.AddRow( - bouncer, - fmt.Sprintf("%d", hits.Empty), - fmt.Sprintf("%d", hits.NonEmpty), - ) - numRows++ - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Bouncers Decisions:") - t.Render() - } -} - -func decisionStatsTable(out io.Writer, stats map[string]map[string]map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Reason", "Origin", "Action", "Count") - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - numRows := 0 - for reason, origins := range stats { - for origin, actions := range origins { - for action, hits := range actions { - t.AddRow( - reason, - origin, - action, - fmt.Sprintf("%d", hits), - ) - numRows++ - } - } - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Decisions:") - t.Render() - } -} - -func alertStatsTable(out io.Writer, stats map[string]int) { - t := newTable(out) - t.SetRowLines(false) - t.SetHeaders("Reason", "Count") - t.SetAlignment(table.AlignLeft, table.AlignLeft) - - numRows := 0 - for scenario, hits := range stats { - t.AddRow( - scenario, - fmt.Sprintf("%d", hits), - ) - numRows++ - } - - if numRows > 0 { - renderTableTitle(out, "\nLocal API Alerts:") - t.Render() - } -} diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go deleted file mode 100644 index 606d8b415a0..00000000000 --- a/cmd/crowdsec-cli/papi.go +++ /dev/null @@ -1,145 +0,0 @@ -package main - -import ( - "time" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/ptr" - - "github.com/crowdsecurity/crowdsec/pkg/apiserver" - "github.com/crowdsecurity/crowdsec/pkg/database" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" -) - -type cliPapi struct {} - -func NewCLIPapi() *cliPapi { - return &cliPapi{} -} - -func (cli cliPapi) NewCommand() *cobra.Command { - var cmd = &cobra.Command{ - Use: "papi [action]", - Short: "Manage interaction with Polling API (PAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := require.LAPI(csConfig); err != nil { - return err - } - if err := require.CAPI(csConfig); err != nil { - return err - } - if err := require.PAPI(csConfig); err != nil { - return err - } - return nil - }, - } - - cmd.AddCommand(cli.NewStatusCmd()) - cmd.AddCommand(cli.NewSyncCmd()) - - return cmd -} - -func (cli cliPapi) NewStatusCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "Get status of the Polling API", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Fatalf("unable to initialize database client : %s", err) - } - - apic, err := apiserver.NewAPIC(csConfig.API.Server.OnlineClient, dbClient, csConfig.API.Server.ConsoleConfig, csConfig.API.Server.CapiWhitelists) - - if err != nil { - log.Fatalf("unable to initialize API client : %s", err) - } - - papi, err := apiserver.NewPAPI(apic, dbClient, csConfig.API.Server.ConsoleConfig, log.GetLevel()) - - if err != nil { - log.Fatalf("unable to initialize PAPI client : %s", err) - } - - perms, err := papi.GetPermissions() - - if err != nil { - log.Fatalf("unable to get PAPI permissions: %s", err) - } - var lastTimestampStr *string - lastTimestampStr, err = dbClient.GetConfigItem(apiserver.PapiPullKey) - if err != nil { - lastTimestampStr = ptr.Of("never") - } - log.Infof("You can successfully interact with Polling API (PAPI)") - log.Infof("Console plan: %s", perms.Plan) - log.Infof("Last order received: %s", *lastTimestampStr) - - log.Infof("PAPI subscriptions:") - for _, sub := range perms.Categories { - log.Infof(" - %s", sub) - } - }, - } - - return cmd -} - -func (cli cliPapi) NewSyncCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "sync", - Short: "Sync with the Polling API, pulling all non-expired orders for the instance", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - t := tomb.Tomb{} - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Fatalf("unable to initialize database client : %s", err) - } - - apic, err := apiserver.NewAPIC(csConfig.API.Server.OnlineClient, dbClient, csConfig.API.Server.ConsoleConfig, csConfig.API.Server.CapiWhitelists) - - if err != nil { - log.Fatalf("unable to initialize API client : %s", err) - } - - t.Go(apic.Push) - - papi, err := apiserver.NewPAPI(apic, dbClient, csConfig.API.Server.ConsoleConfig, log.GetLevel()) - - if err != nil { - log.Fatalf("unable to initialize PAPI client : %s", err) - } - t.Go(papi.SyncDecisions) - - err = papi.PullOnce(time.Time{}, true) - - if err != nil { - log.Fatalf("unable to sync decisions: %s", err) - } - - log.Infof("Sending acknowledgements to CAPI") - - apic.Shutdown() - papi.Shutdown() - t.Wait() - time.Sleep(5 * time.Second) //FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done - - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/reload/reload.go b/cmd/crowdsec-cli/reload/reload.go new file mode 100644 index 00000000000..fe03af1ea79 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload.go @@ -0,0 +1,6 @@ +//go:build !windows && !freebsd && !linux + +package reload + +// generic message since we don't know the platform +const Message = "Please reload the crowdsec process for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_freebsd.go b/cmd/crowdsec-cli/reload/reload_freebsd.go new file mode 100644 index 00000000000..0dac99f2315 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_freebsd.go @@ -0,0 +1,4 @@ +package reload + +// actually sudo is not that popular on freebsd, but this will do +const Message = "Run 'sudo service crowdsec reload' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_linux.go b/cmd/crowdsec-cli/reload/reload_linux.go new file mode 100644 index 00000000000..fbe16e5f168 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_linux.go @@ -0,0 +1,4 @@ +package reload + +// assume systemd, although gentoo and others may differ +const Message = "Run 'sudo systemctl reload crowdsec' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_windows.go b/cmd/crowdsec-cli/reload/reload_windows.go new file mode 100644 index 00000000000..88642425ae2 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_windows.go @@ -0,0 +1,3 @@ +package reload + +const Message = "Please restart the crowdsec service for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/require/branch.go b/cmd/crowdsec-cli/require/branch.go index 6fcaaacea2d..09acc0fef8a 100644 --- a/cmd/crowdsec-cli/require/branch.go +++ b/cmd/crowdsec-cli/require/branch.go @@ -3,54 +3,100 @@ package require // Set the appropriate hub branch according to config settings and crowdsec version import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + log "github.com/sirupsen/logrus" "golang.org/x/mod/semver" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwversion" ) -func chooseBranch(cfg *csconfig.Config) string { +// lookupLatest returns the latest crowdsec version based on github +func lookupLatest(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + url := "https://version.crowdsec.net/latest" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("unable to create request for %s: %w", url, err) + } + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("unable to send request to %s: %w", url, err) + } + defer resp.Body.Close() + + latest := make(map[string]any) + + if err := json.NewDecoder(resp.Body).Decode(&latest); err != nil { + return "", fmt.Errorf("unable to decode response from %s: %w", url, err) + } + + if _, ok := latest["name"]; !ok { + return "", fmt.Errorf("unable to find 'name' key in response from %s", url) + } + + name, ok := latest["name"].(string) + if !ok { + return "", fmt.Errorf("unable to convert 'name' key to string in response from %s", url) + } + + return name, nil +} + +func chooseBranch(ctx context.Context, cfg *csconfig.Config) string { // this was set from config.yaml or flag if cfg.Cscli.HubBranch != "" { log.Debugf("Hub override from config: branch '%s'", cfg.Cscli.HubBranch) return cfg.Cscli.HubBranch } - latest, err := cwversion.Latest() + latest, err := lookupLatest(ctx) if err != nil { log.Warningf("Unable to retrieve latest crowdsec version: %s, using hub branch 'master'", err) return "master" } csVersion := cwversion.VersionStrip() - if csVersion == latest { - log.Debugf("Latest crowdsec version (%s), using hub branch 'master'", csVersion) + if csVersion == "" { + log.Warning("Crowdsec version is not set, using hub branch 'master'") return "master" } - // if current version is greater than the latest we are in pre-release - if semver.Compare(csVersion, latest) == 1 { - log.Debugf("Your current crowdsec version seems to be a pre-release (%s), using hub branch 'master'", csVersion) + if csVersion == latest { + log.Debugf("Latest crowdsec version (%s), using hub branch 'master'", version.String()) return "master" } - if csVersion == "" { - log.Warning("Crowdsec version is not set, using hub branch 'master'") + // if current version is greater than the latest we are in pre-release + if semver.Compare(csVersion, latest) == 1 { + log.Debugf("Your current crowdsec version seems to be a pre-release (%s), using hub branch 'master'", version.String()) return "master" } log.Warnf("A new CrowdSec release is available (%s). "+ "Your version is '%s'. Please update it to use new parsers/scenarios/collections.", latest, csVersion) + return csVersion } - // HubBranch sets the branch (in cscli config) and returns its value // It can be "master", or the branch corresponding to the current crowdsec version, or the value overridden in config/flag -func HubBranch(cfg *csconfig.Config) string { - branch := chooseBranch(cfg) +func HubBranch(ctx context.Context, cfg *csconfig.Config) string { + branch := chooseBranch(ctx, cfg) cfg.Cscli.HubBranch = branch diff --git a/cmd/crowdsec-cli/require/require.go b/cmd/crowdsec-cli/require/require.go index 0f5ce182d9a..191eee55bc5 100644 --- a/cmd/crowdsec-cli/require/require.go +++ b/cmd/crowdsec-cli/require/require.go @@ -1,6 +1,8 @@ package require import ( + "context" + "errors" "fmt" "io" @@ -8,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/database" ) func LAPI(c *csconfig.Config) error { @@ -16,7 +19,7 @@ func LAPI(c *csconfig.Config) error { } if c.DisableAPI { - return fmt.Errorf("local API is disabled -- this command must be run on the local API machine") + return errors.New("local API is disabled -- this command must be run on the local API machine") } return nil @@ -31,8 +34,16 @@ func CAPI(c *csconfig.Config) error { } func PAPI(c *csconfig.Config) error { + if err := CAPI(c); err != nil { + return err + } + + if err := CAPIRegistered(c); err != nil { + return err + } + if c.API.Server.OnlineClient.Credentials.PapiURL == "" { - return fmt.Errorf("no PAPI URL in configuration") + return errors.New("no PAPI URL in configuration") } return nil @@ -40,12 +51,21 @@ func PAPI(c *csconfig.Config) error { func CAPIRegistered(c *csconfig.Config) error { if c.API.Server.OnlineClient.Credentials == nil { - return fmt.Errorf("the Central API (CAPI) must be configured with 'cscli capi register'") + return errors.New("the Central API (CAPI) must be configured with 'cscli capi register'") } return nil } +func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) { + db, err := database.NewClient(ctx, dbcfg) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + return db, nil +} + func DB(c *csconfig.Config) error { if err := c.LoadDBConfig(true); err != nil { return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err) @@ -56,21 +76,21 @@ func DB(c *csconfig.Config) error { func Notifications(c *csconfig.Config) error { if c.ConfigPaths.NotificationDir == "" { - return fmt.Errorf("config_paths.notification_dir is not set in crowdsec config") + return errors.New("config_paths.notification_dir is not set in crowdsec config") } return nil } // RemoteHub returns the configuration required to download hub index and items: url, branch, etc. -func RemoteHub(c *csconfig.Config) *cwhub.RemoteHubCfg { +func RemoteHub(ctx context.Context, c *csconfig.Config) *cwhub.RemoteHubCfg { // set branch in config, and log if necessary - branch := HubBranch(c) + branch := HubBranch(ctx, c) urlTemplate := HubURLTemplate(c) remote := &cwhub.RemoteHubCfg{ Branch: branch, URLTemplate: urlTemplate, - IndexPath: ".index.json", + IndexPath: ".index.json", } return remote @@ -82,7 +102,7 @@ func Hub(c *csconfig.Config, remote *cwhub.RemoteHubCfg, logger *logrus.Logger) local := c.Hub if local == nil { - return nil, fmt.Errorf("you must configure cli before interacting with hub") + return nil, errors.New("you must configure cli before interacting with hub") } if logger == nil { @@ -90,8 +110,12 @@ func Hub(c *csconfig.Config, remote *cwhub.RemoteHubCfg, logger *logrus.Logger) logger.SetOutput(io.Discard) } - hub, err := cwhub.NewHub(local, remote, false, logger) + hub, err := cwhub.NewHub(local, remote, logger) if err != nil { + return nil, err + } + + if err := hub.Load(); err != nil { return nil, fmt.Errorf("failed to read Hub index: %w. Run 'sudo cscli hub update' to download the index again", err) } diff --git a/cmd/crowdsec-cli/setup.go b/cmd/crowdsec-cli/setup.go index 48dcee08905..66c0d71e777 100644 --- a/cmd/crowdsec-cli/setup.go +++ b/cmd/crowdsec-cli/setup.go @@ -1,332 +1,18 @@ +//go:build !no_cscli_setup package main import ( - "bytes" - "fmt" - "os" - "os/exec" - - goccyyaml "github.com/goccy/go-yaml" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/setup" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisetup" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -// NewSetupCmd defines the "cscli setup" command. -func NewSetupCmd() *cobra.Command { - cmdSetup := &cobra.Command{ - Use: "setup", - Short: "Tools to configure crowdsec", - Long: "Manage hub configuration and service detection", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - } - - // - // cscli setup detect - // - { - cmdSetupDetect := &cobra.Command{ - Use: "detect", - Short: "detect running services, generate a setup file", - DisableAutoGenTag: true, - RunE: runSetupDetect, - } - - defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") - - flags := cmdSetupDetect.Flags() - flags.String("detect-config", defaultServiceDetect, "path to service detection configuration") - flags.Bool("list-supported-services", false, "do not detect; only print supported services") - flags.StringSlice("force-unit", nil, "force detection of a systemd unit (can be repeated)") - flags.StringSlice("force-process", nil, "force detection of a running process (can be repeated)") - flags.StringSlice("skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") - flags.String("force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") - flags.String("force-os-id", "", "override OS.ID=[debian | ubuntu | , redhat...]") - flags.String("force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") - flags.Bool("snub-systemd", false, "don't use systemd, even if available") - flags.Bool("yaml", false, "output yaml, not json") - cmdSetup.AddCommand(cmdSetupDetect) - } - - // - // cscli setup install-hub - // - { - cmdSetupInstallHub := &cobra.Command{ - Use: "install-hub [setup_file] [flags]", - Short: "install items from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupInstallHub, - } - - flags := cmdSetupInstallHub.Flags() - flags.Bool("dry-run", false, "don't install anything; print out what would have been") - cmdSetup.AddCommand(cmdSetupInstallHub) - } - - // - // cscli setup datasources - // - { - cmdSetupDataSources := &cobra.Command{ - Use: "datasources [setup_file] [flags]", - Short: "generate datasource (acquisition) configuration from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupDataSources, - } - - flags := cmdSetupDataSources.Flags() - flags.String("to-dir", "", "write the configuration to a directory, in multiple files") - cmdSetup.AddCommand(cmdSetupDataSources) - } - - // - // cscli setup validate - // - { - cmdSetupValidate := &cobra.Command{ - Use: "validate [setup_file]", - Short: "validate a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: runSetupValidate, - } - - cmdSetup.AddCommand(cmdSetupValidate) - } - - return cmdSetup -} - -func runSetupDetect(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - detectConfigFile, err := flags.GetString("detect-config") - if err != nil { - return err - } - - var detectReader *os.File - - switch detectConfigFile { - case "-": - log.Tracef("Reading detection rules from stdin") - detectReader = os.Stdin - default: - log.Tracef("Reading detection rules: %s", detectConfigFile) - detectReader, err = os.Open(detectConfigFile) - if err != nil { - return err - } - } - - listSupportedServices, err := flags.GetBool("list-supported-services") - if err != nil { - return err - } - - forcedUnits, err := flags.GetStringSlice("force-unit") - if err != nil { - return err - } - - forcedProcesses, err := flags.GetStringSlice("force-process") - if err != nil { - return err - } - - forcedOSFamily, err := flags.GetString("force-os-family") - if err != nil { - return err - } - - forcedOSID, err := flags.GetString("force-os-id") - if err != nil { - return err - } - - forcedOSVersion, err := flags.GetString("force-os-version") - if err != nil { - return err - } - - skipServices, err := flags.GetStringSlice("skip-service") - if err != nil { - return err - } - - snubSystemd, err := flags.GetBool("snub-systemd") - if err != nil { - return err - } - - if !snubSystemd { - _, err := exec.LookPath("systemctl") - if err != nil { - log.Debug("systemctl not available: snubbing systemd") - snubSystemd = true - } - } - - outYaml, err := flags.GetBool("yaml") - if err != nil { - return err - } - - if forcedOSFamily == "" && forcedOSID != "" { - log.Debug("force-os-id is set: force-os-family defaults to 'linux'") - forcedOSFamily = "linux" - } - - if listSupportedServices { - supported, err := setup.ListSupported(detectReader) - if err != nil { - return err - } - - for _, svc := range supported { - fmt.Println(svc) - } - - return nil - } - - opts := setup.DetectOptions{ - ForcedUnits: forcedUnits, - ForcedProcesses: forcedProcesses, - ForcedOS: setup.ExprOS{ - Family: forcedOSFamily, - ID: forcedOSID, - RawVersion: forcedOSVersion, - }, - SkipServices: skipServices, - SnubSystemd: snubSystemd, - } - - hubSetup, err := setup.Detect(detectReader, opts) - if err != nil { - return fmt.Errorf("detecting services: %w", err) - } - - setup, err := setupAsString(hubSetup, outYaml) - if err != nil { - return err - } - fmt.Println(setup) - - return nil -} - -func setupAsString(cs setup.Setup, outYaml bool) (string, error) { - var ( - ret []byte - err error - ) - - wrap := func(err error) error { - return fmt.Errorf("while marshaling setup: %w", err) - } - - indentLevel := 2 - buf := &bytes.Buffer{} - enc := yaml.NewEncoder(buf) - enc.SetIndent(indentLevel) - - if err = enc.Encode(cs); err != nil { - return "", wrap(err) - } - - if err = enc.Close(); err != nil { - return "", wrap(err) - } - - ret = buf.Bytes() - - if !outYaml { - // take a general approach to output json, so we avoid the - // double tags in the structures and can use go-yaml features - // missing from the json package - ret, err = goccyyaml.YAMLToJSON(ret) - if err != nil { - return "", wrap(err) - } - } - - return string(ret), nil -} - -func runSetupDataSources(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - fromFile := args[0] - - toDir, err := flags.GetString("to-dir") - if err != nil { - return err - } - - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading setup file: %w", err) - } - - output, err := setup.DataSources(input, toDir) - if err != nil { - return err - } - - if toDir == "" { - fmt.Println(output) - } - - return nil -} - -func runSetupInstallHub(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - fromFile := args[0] - - dryRun, err := flags.GetBool("dry-run") - if err != nil { - return err - } - - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading file %s: %w", fromFile, err) - } - - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), log.StandardLogger()) - if err != nil { - return err - } - - if err = setup.InstallHubItems(hub, input, dryRun); err != nil { - return err - } - - return nil -} - -func runSetupValidate(cmd *cobra.Command, args []string) error { - fromFile := args[0] - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading stdin: %w", err) - } - - if err = setup.Validate(input); err != nil { - fmt.Printf("%v\n", err) - return fmt.Errorf("invalid setup file") +func (cli *cliRoot) addSetup(cmd *cobra.Command) { + if fflag.CscliSetup.IsEnabled() { + cmd.AddCommand(clisetup.New(cli.cfg).NewCommand()) } - return nil + component.Register("cscli_setup") } diff --git a/cmd/crowdsec-cli/setup_stub.go b/cmd/crowdsec-cli/setup_stub.go new file mode 100644 index 00000000000..e001f93c797 --- /dev/null +++ b/cmd/crowdsec-cli/setup_stub.go @@ -0,0 +1,9 @@ +//go:build no_cscli_setup +package main + +import ( + "github.com/spf13/cobra" +) + +func (cli *cliRoot) addSetup(_ *cobra.Command) { +} diff --git a/cmd/crowdsec-cli/simulation.go b/cmd/crowdsec-cli/simulation.go deleted file mode 100644 index 99dac7c17f2..00000000000 --- a/cmd/crowdsec-cli/simulation.go +++ /dev/null @@ -1,269 +0,0 @@ -package main - -import ( - "fmt" - "os" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" - "slices" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -type cliSimulation struct{} - -func NewCLISimulation() *cliSimulation { - return &cliSimulation{} -} - -func (cli cliSimulation) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "simulation [command]", - Short: "Manage simulation status of scenarios", - Example: `cscli simulation status -cscli simulation enable crowdsecurity/ssh-bf -cscli simulation disable crowdsecurity/ssh-bf`, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if err := csConfig.LoadSimulation(); err != nil { - log.Fatal(err) - } - if csConfig.Cscli.SimulationConfig == nil { - return fmt.Errorf("no simulation configured") - } - return nil - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - if cmd.Name() != "status" { - log.Infof(ReloadMessage()) - } - }, - } - cmd.Flags().SortFlags = false - cmd.PersistentFlags().SortFlags = false - - cmd.AddCommand(cli.NewEnableCmd()) - cmd.AddCommand(cli.NewDisableCmd()) - cmd.AddCommand(cli.NewStatusCmd()) - - return cmd -} - -func (cli cliSimulation) NewEnableCmd() *cobra.Command { - var forceGlobalSimulation bool - - cmd := &cobra.Command{ - Use: "enable [scenario] [-global]", - Short: "Enable the simulation, globally or on specified scenarios", - Example: `cscli simulation enable`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - log.Fatal(err) - } - - if len(args) > 0 { - for _, scenario := range args { - var item = hub.GetItem(cwhub.SCENARIOS, scenario) - if item == nil { - log.Errorf("'%s' doesn't exist or is not a scenario", scenario) - continue - } - if !item.State.Installed { - log.Warningf("'%s' isn't enabled", scenario) - } - isExcluded := slices.Contains(csConfig.Cscli.SimulationConfig.Exclusions, scenario) - if *csConfig.Cscli.SimulationConfig.Simulation && !isExcluded { - log.Warning("global simulation is already enabled") - continue - } - if !*csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - log.Warningf("simulation for '%s' already enabled", scenario) - continue - } - if *csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - if err := removeFromExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation enabled for '%s'", scenario) - continue - } - if err := addToExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' enabled", scenario) - } - if err := dumpSimulationFile(); err != nil { - log.Fatalf("simulation enable: %s", err) - } - } else if forceGlobalSimulation { - if err := enableGlobalSimulation(); err != nil { - log.Fatalf("unable to enable global simulation mode : %s", err) - } - } else { - printHelp(cmd) - } - }, - } - cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Enable global simulation (reverse mode)") - - return cmd -} - -func (cli cliSimulation) NewDisableCmd() *cobra.Command { - var forceGlobalSimulation bool - - cmd := &cobra.Command{ - Use: "disable [scenario]", - Short: "Disable the simulation mode. Disable only specified scenarios", - Example: `cscli simulation disable`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if len(args) > 0 { - for _, scenario := range args { - isExcluded := slices.Contains(csConfig.Cscli.SimulationConfig.Exclusions, scenario) - if !*csConfig.Cscli.SimulationConfig.Simulation && !isExcluded { - log.Warningf("%s isn't in simulation mode", scenario) - continue - } - if !*csConfig.Cscli.SimulationConfig.Simulation && isExcluded { - if err := removeFromExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' disabled", scenario) - continue - } - if isExcluded { - log.Warningf("simulation mode is enabled but is already disable for '%s'", scenario) - continue - } - if err := addToExclusion(scenario); err != nil { - log.Fatal(err) - } - log.Printf("simulation mode for '%s' disabled", scenario) - } - if err := dumpSimulationFile(); err != nil { - log.Fatalf("simulation disable: %s", err) - } - } else if forceGlobalSimulation { - if err := disableGlobalSimulation(); err != nil { - log.Fatalf("unable to disable global simulation mode : %s", err) - } - } else { - printHelp(cmd) - } - }, - } - cmd.Flags().BoolVarP(&forceGlobalSimulation, "global", "g", false, "Disable global simulation (reverse mode)") - - return cmd -} - -func (cli cliSimulation) NewStatusCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "Show simulation mode status", - Example: `cscli simulation status`, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - if err := simulationStatus(); err != nil { - log.Fatal(err) - } - }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - }, - } - - return cmd -} - -func addToExclusion(name string) error { - csConfig.Cscli.SimulationConfig.Exclusions = append(csConfig.Cscli.SimulationConfig.Exclusions, name) - return nil -} - -func removeFromExclusion(name string) error { - index := slices.Index(csConfig.Cscli.SimulationConfig.Exclusions, name) - - // Remove element from the slice - csConfig.Cscli.SimulationConfig.Exclusions[index] = csConfig.Cscli.SimulationConfig.Exclusions[len(csConfig.Cscli.SimulationConfig.Exclusions)-1] - csConfig.Cscli.SimulationConfig.Exclusions[len(csConfig.Cscli.SimulationConfig.Exclusions)-1] = "" - csConfig.Cscli.SimulationConfig.Exclusions = csConfig.Cscli.SimulationConfig.Exclusions[:len(csConfig.Cscli.SimulationConfig.Exclusions)-1] - - return nil -} - -func enableGlobalSimulation() error { - csConfig.Cscli.SimulationConfig.Simulation = new(bool) - *csConfig.Cscli.SimulationConfig.Simulation = true - csConfig.Cscli.SimulationConfig.Exclusions = []string{} - - if err := dumpSimulationFile(); err != nil { - log.Fatalf("unable to dump simulation file: %s", err) - } - - log.Printf("global simulation: enabled") - - return nil -} - -func dumpSimulationFile() error { - newConfigSim, err := yaml.Marshal(csConfig.Cscli.SimulationConfig) - if err != nil { - return fmt.Errorf("unable to marshal simulation configuration: %s", err) - } - err = os.WriteFile(csConfig.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) - if err != nil { - return fmt.Errorf("write simulation config in '%s' failed: %s", csConfig.ConfigPaths.SimulationFilePath, err) - } - log.Debugf("updated simulation file %s", csConfig.ConfigPaths.SimulationFilePath) - - return nil -} - -func disableGlobalSimulation() error { - csConfig.Cscli.SimulationConfig.Simulation = new(bool) - *csConfig.Cscli.SimulationConfig.Simulation = false - - csConfig.Cscli.SimulationConfig.Exclusions = []string{} - newConfigSim, err := yaml.Marshal(csConfig.Cscli.SimulationConfig) - if err != nil { - return fmt.Errorf("unable to marshal new simulation configuration: %s", err) - } - err = os.WriteFile(csConfig.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) - if err != nil { - return fmt.Errorf("unable to write new simulation config in '%s' : %s", csConfig.ConfigPaths.SimulationFilePath, err) - } - - log.Printf("global simulation: disabled") - return nil -} - -func simulationStatus() error { - if csConfig.Cscli.SimulationConfig == nil { - log.Printf("global simulation: disabled (configuration file is missing)") - return nil - } - if *csConfig.Cscli.SimulationConfig.Simulation { - log.Println("global simulation: enabled") - if len(csConfig.Cscli.SimulationConfig.Exclusions) > 0 { - log.Println("Scenarios not in simulation mode :") - for _, scenario := range csConfig.Cscli.SimulationConfig.Exclusions { - log.Printf(" - %s", scenario) - } - } - } else { - log.Println("global simulation: disabled") - if len(csConfig.Cscli.SimulationConfig.Exclusions) > 0 { - log.Println("Scenarios in simulation mode :") - for _, scenario := range csConfig.Cscli.SimulationConfig.Exclusions { - log.Printf(" - %s", scenario) - } - } - } - return nil -} diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go deleted file mode 100644 index 47768e7c2ee..00000000000 --- a/cmd/crowdsec-cli/support.go +++ /dev/null @@ -1,439 +0,0 @@ -package main - -import ( - "archive/zip" - "bytes" - "context" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/blackfireio/osinfo" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/go-cs-lib/version" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/fflag" - "github.com/crowdsecurity/crowdsec/pkg/models" -) - -const ( - SUPPORT_METRICS_HUMAN_PATH = "metrics/metrics.human" - SUPPORT_METRICS_PROMETHEUS_PATH = "metrics/metrics.prometheus" - SUPPORT_VERSION_PATH = "version.txt" - SUPPORT_FEATURES_PATH = "features.txt" - SUPPORT_OS_INFO_PATH = "osinfo.txt" - SUPPORT_PARSERS_PATH = "hub/parsers.txt" - SUPPORT_SCENARIOS_PATH = "hub/scenarios.txt" - SUPPORT_CONTEXTS_PATH = "hub/scenarios.txt" - SUPPORT_COLLECTIONS_PATH = "hub/collections.txt" - SUPPORT_POSTOVERFLOWS_PATH = "hub/postoverflows.txt" - SUPPORT_BOUNCERS_PATH = "lapi/bouncers.txt" - SUPPORT_AGENTS_PATH = "lapi/agents.txt" - SUPPORT_CROWDSEC_CONFIG_PATH = "config/crowdsec.yaml" - SUPPORT_LAPI_STATUS_PATH = "lapi_status.txt" - SUPPORT_CAPI_STATUS_PATH = "capi_status.txt" - SUPPORT_ACQUISITION_CONFIG_BASE_PATH = "config/acquis/" - SUPPORT_CROWDSEC_PROFILE_PATH = "config/profiles.yaml" -) - -// from https://github.com/acarl005/stripansi -var reStripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") - -func stripAnsiString(str string) string { - // the byte version doesn't strip correctly - return reStripAnsi.ReplaceAllString(str, "") -} - -func collectMetrics() ([]byte, []byte, error) { - log.Info("Collecting prometheus metrics") - - if csConfig.Cscli.PrometheusUrl == "" { - log.Warn("No Prometheus URL configured, metrics will not be collected") - return nil, nil, fmt.Errorf("prometheus_uri is not set") - } - - humanMetrics := bytes.NewBuffer(nil) - err := FormatPrometheusMetrics(humanMetrics, csConfig.Cscli.PrometheusUrl, "human") - - if err != nil { - return nil, nil, fmt.Errorf("could not fetch promtheus metrics: %s", err) - } - - req, err := http.NewRequest(http.MethodGet, csConfig.Cscli.PrometheusUrl, nil) - if err != nil { - return nil, nil, fmt.Errorf("could not create requests to prometheus endpoint: %s", err) - } - client := &http.Client{} - resp, err := client.Do(req) - - if err != nil { - return nil, nil, fmt.Errorf("could not get metrics from prometheus endpoint: %s", err) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("could not read metrics from prometheus endpoint: %s", err) - } - - return humanMetrics.Bytes(), body, nil -} - -func collectVersion() []byte { - log.Info("Collecting version") - return []byte(cwversion.ShowStr()) -} - -func collectFeatures() []byte { - log.Info("Collecting feature flags") - enabledFeatures := fflag.Crowdsec.GetEnabledFeatures() - - w := bytes.NewBuffer(nil) - for _, k := range enabledFeatures { - fmt.Fprintf(w, "%s\n", k) - } - return w.Bytes() -} - -func collectOSInfo() ([]byte, error) { - log.Info("Collecting OS info") - info, err := osinfo.GetOSInfo() - - if err != nil { - return nil, err - } - - w := bytes.NewBuffer(nil) - w.WriteString(fmt.Sprintf("Architecture: %s\n", info.Architecture)) - w.WriteString(fmt.Sprintf("Family: %s\n", info.Family)) - w.WriteString(fmt.Sprintf("ID: %s\n", info.ID)) - w.WriteString(fmt.Sprintf("Name: %s\n", info.Name)) - w.WriteString(fmt.Sprintf("Codename: %s\n", info.Codename)) - w.WriteString(fmt.Sprintf("Version: %s\n", info.Version)) - w.WriteString(fmt.Sprintf("Build: %s\n", info.Build)) - - return w.Bytes(), nil -} - -func collectHubItems(hub *cwhub.Hub, itemType string) []byte { - var err error - - out := bytes.NewBuffer(nil) - log.Infof("Collecting %s list", itemType) - - items := make(map[string][]*cwhub.Item) - - if items[itemType], err = selectItems(hub, itemType, nil, true); err != nil { - log.Warnf("could not collect %s list: %s", itemType, err) - } - - if err := listItems(out, []string{itemType}, items, false); err != nil { - log.Warnf("could not collect %s list: %s", itemType, err) - } - return out.Bytes() -} - -func collectBouncers(dbClient *database.Client) ([]byte, error) { - out := bytes.NewBuffer(nil) - bouncers, err := dbClient.ListBouncers() - if err != nil { - return nil, fmt.Errorf("unable to list bouncers: %s", err) - } - getBouncersTable(out, bouncers) - return out.Bytes(), nil -} - -func collectAgents(dbClient *database.Client) ([]byte, error) { - out := bytes.NewBuffer(nil) - err := getAgents(out, dbClient) - if err != nil { - return nil, err - } - return out.Bytes(), nil -} - -func collectAPIStatus(login string, password string, endpoint string, prefix string, hub *cwhub.Hub) []byte { - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil { - return []byte("No agent credentials found, are we LAPI ?") - } - pwd := strfmt.Password(password) - apiurl, err := url.Parse(endpoint) - - if err != nil { - return []byte(fmt.Sprintf("cannot parse API URL: %s", err)) - } - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return []byte(fmt.Sprintf("could not collect scenarios: %s", err)) - } - - Client, err = apiclient.NewDefaultClient(apiurl, - prefix, - fmt.Sprintf("crowdsec/%s", version.String()), - nil) - if err != nil { - return []byte(fmt.Sprintf("could not init client: %s", err)) - } - t := models.WatcherAuthRequest{ - MachineID: &login, - Password: &pwd, - Scenarios: scenarios, - } - - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return []byte(fmt.Sprintf("Could not authenticate to API: %s", err)) - } else { - return []byte("Successfully authenticated to LAPI") - } -} - -func collectCrowdsecConfig() []byte { - log.Info("Collecting crowdsec config") - config, err := os.ReadFile(*csConfig.FilePath) - if err != nil { - return []byte(fmt.Sprintf("could not read config file: %s", err)) - } - - r := regexp.MustCompile(`(\s+password:|\s+user:|\s+host:)\s+.*`) - - return r.ReplaceAll(config, []byte("$1 ****REDACTED****")) -} - -func collectCrowdsecProfile() []byte { - log.Info("Collecting crowdsec profile") - config, err := os.ReadFile(csConfig.API.Server.ProfilesPath) - if err != nil { - return []byte(fmt.Sprintf("could not read profile file: %s", err)) - } - return config -} - -func collectAcquisitionConfig() map[string][]byte { - log.Info("Collecting acquisition config") - ret := make(map[string][]byte) - - for _, filename := range csConfig.Crowdsec.AcquisitionFiles { - fileContent, err := os.ReadFile(filename) - if err != nil { - ret[filename] = []byte(fmt.Sprintf("could not read file: %s", err)) - } else { - ret[filename] = fileContent - } - } - - return ret -} - -type cliSupport struct{} - -func NewCLISupport() *cliSupport { - return &cliSupport{} -} - -func (cli cliSupport) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "support [action]", - Short: "Provide commands to help during support", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return nil - }, - } - - cmd.AddCommand(cli.NewDumpCmd()) - - return cmd -} - -func (cli cliSupport) NewDumpCmd() *cobra.Command { - var outFile string - - cmd := &cobra.Command{ - Use: "dump", - Short: "Dump all your configuration to a zip file for easier support", - Long: `Dump the following informations: -- Crowdsec version -- OS version -- Installed collections list -- Installed parsers list -- Installed scenarios list -- Installed postoverflows list -- Installed context list -- Bouncers list -- Machines list -- CAPI status -- LAPI status -- Crowdsec config (sensitive information like username and password are redacted) -- Crowdsec metrics`, - Example: `cscli support dump -cscli support dump -f /tmp/crowdsec-support.zip -`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - Run: func(cmd *cobra.Command, args []string) { - var err error - var skipHub, skipDB, skipCAPI, skipLAPI, skipAgent bool - infos := map[string][]byte{ - SUPPORT_VERSION_PATH: collectVersion(), - SUPPORT_FEATURES_PATH: collectFeatures(), - } - - if outFile == "" { - outFile = "/tmp/crowdsec-support.zip" - } - - dbClient, err = database.NewClient(csConfig.DbConfig) - if err != nil { - log.Warnf("Could not connect to database: %s", err) - skipDB = true - infos[SUPPORT_BOUNCERS_PATH] = []byte(err.Error()) - infos[SUPPORT_AGENTS_PATH] = []byte(err.Error()) - } - - if err := csConfig.LoadAPIServer(true); err != nil { - log.Warnf("could not load LAPI, skipping CAPI check") - skipLAPI = true - infos[SUPPORT_CAPI_STATUS_PATH] = []byte(err.Error()) - } - - if err := csConfig.LoadCrowdsec(); err != nil { - log.Warnf("could not load agent config, skipping crowdsec config check") - skipAgent = true - } - - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - log.Warn("Could not init hub, running on LAPI ? Hub related information will not be collected") - skipHub = true - infos[SUPPORT_PARSERS_PATH] = []byte(err.Error()) - infos[SUPPORT_SCENARIOS_PATH] = []byte(err.Error()) - infos[SUPPORT_POSTOVERFLOWS_PATH] = []byte(err.Error()) - infos[SUPPORT_CONTEXTS_PATH] = []byte(err.Error()) - infos[SUPPORT_COLLECTIONS_PATH] = []byte(err.Error()) - } - - if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil { - log.Warn("no agent credentials found, skipping LAPI connectivity check") - if _, ok := infos[SUPPORT_LAPI_STATUS_PATH]; ok { - infos[SUPPORT_LAPI_STATUS_PATH] = append(infos[SUPPORT_LAPI_STATUS_PATH], []byte("\nNo LAPI credentials found")...) - } - skipLAPI = true - } - - if csConfig.API.Server == nil || csConfig.API.Server.OnlineClient == nil || csConfig.API.Server.OnlineClient.Credentials == nil { - log.Warn("no CAPI credentials found, skipping CAPI connectivity check") - skipCAPI = true - } - - infos[SUPPORT_METRICS_HUMAN_PATH], infos[SUPPORT_METRICS_PROMETHEUS_PATH], err = collectMetrics() - if err != nil { - log.Warnf("could not collect prometheus metrics information: %s", err) - infos[SUPPORT_METRICS_HUMAN_PATH] = []byte(err.Error()) - infos[SUPPORT_METRICS_PROMETHEUS_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_OS_INFO_PATH], err = collectOSInfo() - if err != nil { - log.Warnf("could not collect OS information: %s", err) - infos[SUPPORT_OS_INFO_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_CROWDSEC_CONFIG_PATH] = collectCrowdsecConfig() - - if !skipHub { - infos[SUPPORT_PARSERS_PATH] = collectHubItems(hub, cwhub.PARSERS) - infos[SUPPORT_SCENARIOS_PATH] = collectHubItems(hub, cwhub.SCENARIOS) - infos[SUPPORT_POSTOVERFLOWS_PATH] = collectHubItems(hub, cwhub.POSTOVERFLOWS) - infos[SUPPORT_CONTEXTS_PATH] = collectHubItems(hub, cwhub.POSTOVERFLOWS) - infos[SUPPORT_COLLECTIONS_PATH] = collectHubItems(hub, cwhub.COLLECTIONS) - } - - if !skipDB { - infos[SUPPORT_BOUNCERS_PATH], err = collectBouncers(dbClient) - if err != nil { - log.Warnf("could not collect bouncers information: %s", err) - infos[SUPPORT_BOUNCERS_PATH] = []byte(err.Error()) - } - - infos[SUPPORT_AGENTS_PATH], err = collectAgents(dbClient) - if err != nil { - log.Warnf("could not collect agents information: %s", err) - infos[SUPPORT_AGENTS_PATH] = []byte(err.Error()) - } - } - - if !skipCAPI { - log.Info("Collecting CAPI status") - infos[SUPPORT_CAPI_STATUS_PATH] = collectAPIStatus(csConfig.API.Server.OnlineClient.Credentials.Login, - csConfig.API.Server.OnlineClient.Credentials.Password, - csConfig.API.Server.OnlineClient.Credentials.URL, - CAPIURLPrefix, - hub) - } - - if !skipLAPI { - log.Info("Collection LAPI status") - infos[SUPPORT_LAPI_STATUS_PATH] = collectAPIStatus(csConfig.API.Client.Credentials.Login, - csConfig.API.Client.Credentials.Password, - csConfig.API.Client.Credentials.URL, - LAPIURLPrefix, - hub) - infos[SUPPORT_CROWDSEC_PROFILE_PATH] = collectCrowdsecProfile() - } - - if !skipAgent { - - acquis := collectAcquisitionConfig() - - for filename, content := range acquis { - fname := strings.ReplaceAll(filename, string(filepath.Separator), "___") - infos[SUPPORT_ACQUISITION_CONFIG_BASE_PATH+fname] = content - } - } - - w := bytes.NewBuffer(nil) - zipWriter := zip.NewWriter(w) - - for filename, data := range infos { - fw, err := zipWriter.Create(filename) - if err != nil { - log.Errorf("Could not add zip entry for %s: %s", filename, err) - continue - } - fw.Write([]byte(stripAnsiString(string(data)))) - } - - err = zipWriter.Close() - if err != nil { - log.Fatalf("could not finalize zip file: %s", err) - } - - err = os.WriteFile(outFile, w.Bytes(), 0o600) - if err != nil { - log.Fatalf("could not write zip file to %s: %s", outFile, err) - } - - log.Infof("Written zip file to %s", outFile) - }, - } - - cmd.Flags().StringVarP(&outFile, "outFile", "f", "", "File to dump the information to") - - return cmd -} diff --git a/cmd/crowdsec-cli/tables.go b/cmd/crowdsec-cli/tables.go deleted file mode 100644 index 2c3173d0b0b..00000000000 --- a/cmd/crowdsec-cli/tables.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - - "github.com/aquasecurity/table" - isatty "github.com/mattn/go-isatty" -) - -func shouldWeColorize() bool { - if csConfig.Cscli.Color == "yes" { - return true - } - if csConfig.Cscli.Color == "no" { - return false - } - return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) -} - -func newTable(out io.Writer) *table.Table { - if out == nil { - panic("newTable: out is nil") - } - t := table.New(out) - if shouldWeColorize() { - t.SetLineStyle(table.StyleBrightBlack) - t.SetHeaderStyle(table.StyleItalic) - } - - if shouldWeColorize() { - t.SetDividers(table.UnicodeRoundedDividers) - } else { - t.SetDividers(table.ASCIIDividers) - } - - return t -} - -func newLightTable(out io.Writer) *table.Table { - if out == nil { - panic("newTable: out is nil") - } - t := newTable(out) - t.SetRowLines(false) - t.SetBorderLeft(false) - t.SetBorderRight(false) - // This leaves three spaces between columns: - // left padding, invisible border, right padding - // There is no way to make two spaces without - // a SetColumnLines() method, but it's close enough. - t.SetPadding(1) - - if shouldWeColorize() { - t.SetDividers(table.Dividers{ - ALL: "─", - NES: "─", - NSW: "─", - NEW: "─", - ESW: "─", - NE: "─", - NW: "─", - SW: "─", - ES: "─", - EW: "─", - NS: " ", - }) - } else { - t.SetDividers(table.Dividers{ - ALL: "-", - NES: "-", - NSW: "-", - NEW: "-", - ESW: "-", - NE: "-", - NW: "-", - SW: "-", - ES: "-", - EW: "-", - NS: " ", - }) - } - return t -} - -func renderTableTitle(out io.Writer, title string) { - if out == nil { - panic("renderTableTitle: out is nil") - } - if title == "" { - return - } - fmt.Fprintln(out, title) -} diff --git a/cmd/crowdsec-cli/utils.go b/cmd/crowdsec-cli/utils.go deleted file mode 100644 index b568c6eae3f..00000000000 --- a/cmd/crowdsec-cli/utils.go +++ /dev/null @@ -1,70 +0,0 @@ -package main - -import ( - "fmt" - "net" - "strings" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func printHelp(cmd *cobra.Command) { - if err := cmd.Help(); err != nil { - log.Fatalf("unable to print help(): %s", err) - } -} - -func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *string) error { - /*if a range is provided, change the scope*/ - if *ipRange != "" { - _, _, err := net.ParseCIDR(*ipRange) - if err != nil { - return fmt.Errorf("%s isn't a valid range", *ipRange) - } - } - if *ip != "" { - ipRepr := net.ParseIP(*ip) - if ipRepr == nil { - return fmt.Errorf("%s isn't a valid ip", *ip) - } - } - - //avoid confusion on scope (ip vs Ip and range vs Range) - switch strings.ToLower(*scope) { - case "ip": - *scope = types.Ip - case "range": - *scope = types.Range - case "country": - *scope = types.Country - case "as": - *scope = types.AS - } - return nil -} - -func removeFromSlice(val string, slice []string) []string { - var i int - var value string - - valueFound := false - - // get the index - for i, value = range slice { - if value == val { - valueFound = true - break - } - } - - if valueFound { - slice[i] = slice[len(slice)-1] - slice[len(slice)-1] = "" - slice = slice[:len(slice)-1] - } - - return slice -} diff --git a/cmd/crowdsec-cli/utils_table.go b/cmd/crowdsec-cli/utils_table.go deleted file mode 100644 index b1e4b6950b3..00000000000 --- a/cmd/crowdsec-cli/utils_table.go +++ /dev/null @@ -1,83 +0,0 @@ -package main - -import ( - "fmt" - "io" - "strconv" - - "github.com/aquasecurity/table" - "github.com/enescakir/emoji" - - "github.com/crowdsecurity/crowdsec/pkg/cwhub" -) - -func listHubItemTable(out io.Writer, title string, items []*cwhub.Item) { - t := newLightTable(out) - t.SetHeaders("Name", fmt.Sprintf("%v Status", emoji.Package), "Version", "Local Path") - t.SetHeaderAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft) - - for _, item := range items { - status := fmt.Sprintf("%v %s", item.State.Emoji(), item.State.Text()) - t.AddRow(item.Name, status, item.State.LocalVersion, item.State.LocalPath) - } - renderTableTitle(out, title) - t.Render() -} - -func appsecMetricsTable(out io.Writer, itemName string, metrics map[string]int) { - t := newTable(out) - t.SetHeaders("Inband Hits", "Outband Hits") - - t.AddRow( - strconv.Itoa(metrics["inband_hits"]), - strconv.Itoa(metrics["outband_hits"]), - ) - - renderTableTitle(out, fmt.Sprintf("\n - (AppSec Rule) %s:", itemName)) - t.Render() -} - -func scenarioMetricsTable(out io.Writer, itemName string, metrics map[string]int) { - if metrics["instantiation"] == 0 { - return - } - t := newTable(out) - t.SetHeaders("Current Count", "Overflows", "Instantiated", "Poured", "Expired") - - t.AddRow( - strconv.Itoa(metrics["curr_count"]), - strconv.Itoa(metrics["overflow"]), - strconv.Itoa(metrics["instantiation"]), - strconv.Itoa(metrics["pour"]), - strconv.Itoa(metrics["underflow"]), - ) - - renderTableTitle(out, fmt.Sprintf("\n - (Scenario) %s:", itemName)) - t.Render() -} - -func parserMetricsTable(out io.Writer, itemName string, metrics map[string]map[string]int) { - t := newTable(out) - t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed") - - // don't show table if no hits - showTable := false - - for source, stats := range metrics { - if stats["hits"] > 0 { - t.AddRow( - source, - strconv.Itoa(stats["hits"]), - strconv.Itoa(stats["parsed"]), - strconv.Itoa(stats["unparsed"]), - ) - showTable = true - } - } - - if showTable { - renderTableTitle(out, fmt.Sprintf("\n - (Parser) %s:", itemName)) - t.Render() - } -} diff --git a/cmd/crowdsec-cli/version.go b/cmd/crowdsec-cli/version.go index de36c9be28f..7ec5c459968 100644 --- a/cmd/crowdsec-cli/version.go +++ b/cmd/crowdsec-cli/version.go @@ -1,6 +1,8 @@ package main import ( + "os" + "github.com/spf13/cobra" "github.com/crowdsecurity/crowdsec/pkg/cwversion" @@ -12,14 +14,14 @@ func NewCLIVersion() *cliVersion { return &cliVersion{} } -func (cli cliVersion) NewCommand() *cobra.Command { +func (cliVersion) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "version", Short: "Display version", - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, Run: func(_ *cobra.Command, _ []string) { - cwversion.Show() + _, _ = os.Stdout.WriteString(cwversion.FullString()) }, } diff --git a/cmd/crowdsec/Makefile b/cmd/crowdsec/Makefile index 7425d970ad1..39f807cab88 100644 --- a/cmd/crowdsec/Makefile +++ b/cmd/crowdsec/Makefile @@ -10,13 +10,6 @@ GOTEST = $(GO) test CROWDSEC_BIN = crowdsec$(EXT) # names longer than 15 chars break 'pgrep' -PREFIX ?= "/" -CFG_PREFIX = $(PREFIX)"/etc/crowdsec/config/" -BIN_PREFIX = $(PREFIX)"/usr/local/bin/" -DATA_PREFIX = $(PREFIX)"/var/run/crowdsec/" -PID_DIR = $(PREFIX)"/var/run/" - -SYSTEMD_PATH_FILE = "/etc/systemd/system/crowdsec.service" .PHONY: all all: clean test build @@ -29,41 +22,3 @@ test: clean: @$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR) - -.PHONY: install -install: install-conf install-bin - -.PHONY: install-conf -install-conf: - mkdir -p $(DATA_PREFIX) || exit - (cd ../.. / && find ./data -type f -exec install -Dm 755 "{}" "$(DATA_PREFIX){}" \; && cd ./cmd/crowdsec) || exit - (cd ../../config && find ./patterns -type f -exec install -Dm 755 "{}" "$(CFG_PREFIX){}" \; && cd ../cmd/crowdsec) || exit - mkdir -p "$(CFG_PREFIX)" || exit - mkdir -p "$(CFG_PREFIX)/parsers" || exit - mkdir -p "$(CFG_PREFIX)/scenarios" || exit - mkdir -p "$(CFG_PREFIX)/postoverflows" || exit - mkdir -p "$(CFG_PREFIX)/collections" || exit - mkdir -p "$(CFG_PREFIX)/patterns" || exit - install -v -m 755 -D ../../config/prod.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/dev.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/acquis.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/profiles.yaml "$(CFG_PREFIX)" || exit - install -v -m 755 -D ../../config/api.yaml "$(CFG_PREFIX)" || exit - mkdir -p $(PID_DIR) || exit - PID=$(PID_DIR) DATA=$(DATA_PREFIX)"/data/" CFG=$(CFG_PREFIX) envsubst < ../../config/prod.yaml > $(CFG_PREFIX)"/default.yaml" - -.PHONY: install-bin -install-bin: - install -v -m 755 -D "$(CROWDSEC_BIN)" "$(BIN_PREFIX)/$(CROWDSEC_BIN)" || exit - -.PHONY: systemd -systemd: install - CFG=$(CFG_PREFIX) PID=$(PID_DIR) BIN=$(BIN_PREFIX)"/"$(CROWDSEC_BIN) envsubst < ../../config/crowdsec.service > "$(SYSTEMD_PATH_FILE)" - systemctl daemon-reload - -.PHONY: uninstall -uninstall: - $(RM) $(CFG_PREFIX) $(WIN_IGNORE_ERR) - $(RM) $(DATA_PREFIX) $(WIN_IGNORE_ERR) - $(RM) "$(BIN_PREFIX)/$(CROWDSEC_BIN)" $(WIN_IGNORE_ERR) - $(RM) "$(SYSTEMD_PATH_FILE)" $(WIN_IGNORE_ERR) diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index a1e933cba89..ccb0acf0209 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,11 +1,12 @@ package main import ( + "context" + "errors" "fmt" "runtime" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/trace" @@ -14,12 +15,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { +func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) { if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil { log.Info("push and pull to Central API disabled") } - apiServer, err := apiserver.NewServer(cConfig.API.Server) + apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server) if err != nil { return nil, fmt.Errorf("unable to run local API: %w", err) } @@ -39,7 +40,7 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined") } - err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) + err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) if err != nil { return nil, fmt.Errorf("unable to run plugin broker: %w", err) } @@ -56,12 +57,16 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return apiServer, nil } -func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { +func serveAPIServer(apiServer *apiserver.APIServer) { + apiReady := make(chan bool, 1) + apiTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveAPIServer") + go func() { defer trace.CatchPanic("crowdsec/runAPIServer") log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) + if err := apiServer.Run(apiReady); err != nil { log.Fatal(err) } @@ -75,11 +80,10 @@ func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) { <-apiTomb.Dying() // lock until go routine is dying pluginTomb.Kill(nil) log.Infof("serve: shutting down api server") - if err := apiServer.Shutdown(); err != nil { - return err - } - return nil + + return apiServer.Shutdown() }) + <-apiReady } func hasPlugins(profiles []*csconfig.ProfileCfg) bool { @@ -88,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool { return true } } + return false } diff --git a/cmd/crowdsec/appsec.go b/cmd/crowdsec/appsec.go new file mode 100644 index 00000000000..cb02b137dcd --- /dev/null +++ b/cmd/crowdsec/appsec.go @@ -0,0 +1,18 @@ +// +build !no_datasource_appsec + +package main + +import ( + "fmt" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + if err := appsec.LoadAppsecRules(hub); err != nil { + return fmt.Errorf("while loading appsec rules: %w", err) + } + + return nil +} diff --git a/cmd/crowdsec/appsec_stub.go b/cmd/crowdsec/appsec_stub.go new file mode 100644 index 00000000000..4a65b32a9ad --- /dev/null +++ b/cmd/crowdsec/appsec_stub.go @@ -0,0 +1,11 @@ +//go:build no_datasource_appsec + +package main + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + return nil +} diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 774b9d381cf..c44d71d2093 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -1,241 +1,228 @@ package main import ( + "context" "fmt" "os" - "path/filepath" "sync" "time" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition" - "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/parser" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, error) { +// initCrowdsec prepares the log processor service +func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, []acquisition.DataSource, error) { var err error if err = alertcontext.LoadConsoleContext(cConfig, hub); err != nil { - return nil, fmt.Errorf("while loading context: %w", err) + return nil, nil, fmt.Errorf("while loading context: %w", err) + } + + err = exprhelpers.GeoIPInit(hub.GetDataDir()) + if err != nil { + // GeoIP databases are not mandatory, do not make crowdsec fail if they are not present + log.Warnf("unable to initialize GeoIP: %s", err) } // Start loading configs csParsers := parser.NewParsers(hub) if csParsers, err = parser.LoadParsers(cConfig, csParsers); err != nil { - return nil, fmt.Errorf("while loading parsers: %w", err) + return nil, nil, fmt.Errorf("while loading parsers: %w", err) } - if err := LoadBuckets(cConfig, hub); err != nil { - return nil, fmt.Errorf("while loading scenarios: %w", err) + if err = LoadBuckets(cConfig, hub); err != nil { + return nil, nil, fmt.Errorf("while loading scenarios: %w", err) } - if err := appsec.LoadAppsecRules(hub); err != nil { - return nil, fmt.Errorf("while loading appsec rules: %w", err) + // can be nerfed by a build flag + if err = LoadAppsecRules(hub); err != nil { + return nil, nil, err } - if err := LoadAcquisition(cConfig); err != nil { - return nil, fmt.Errorf("while loading acquisition config: %w", err) + datasources, err := LoadAcquisition(cConfig) + if err != nil { + return nil, nil, fmt.Errorf("while loading acquisition config: %w", err) } - return csParsers, nil + return csParsers, datasources, nil } -func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.Hub) error { +// runCrowdsec starts the log processor service +func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.Hub, datasources []acquisition.DataSource) error { inputEventChan = make(chan types.Event) inputLineChan = make(chan types.Event) - //start go-routines for parsing, buckets pour and outputs. + // start go-routines for parsing, buckets pour and outputs. parserWg := &sync.WaitGroup{} + parsersTomb.Go(func() error { parserWg.Add(1) - for i := 0; i < cConfig.Crowdsec.ParserRoutinesCount; i++ { + + for range cConfig.Crowdsec.ParserRoutinesCount { parsersTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runParse") - if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { //this error will never happen as parser.Parse is not able to return errors - log.Fatalf("starting parse error : %s", err) + + if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { + // this error will never happen as parser.Parse is not able to return errors return err } + return nil }) } + parserWg.Done() + return nil }) parserWg.Wait() bucketWg := &sync.WaitGroup{} + bucketsTomb.Go(func() error { bucketWg.Add(1) - /*restore previous state as well if present*/ + // restore previous state as well if present if cConfig.Crowdsec.BucketStateFile != "" { log.Warningf("Restoring buckets state from %s", cConfig.Crowdsec.BucketStateFile) + if err := leaky.LoadBucketsState(cConfig.Crowdsec.BucketStateFile, buckets, holders); err != nil { - return fmt.Errorf("unable to restore buckets : %s", err) + return fmt.Errorf("unable to restore buckets: %w", err) } } - for i := 0; i < cConfig.Crowdsec.BucketsRoutinesCount; i++ { + for range cConfig.Crowdsec.BucketsRoutinesCount { bucketsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runPour") - if err := runPour(inputEventChan, holders, buckets, cConfig); err != nil { - log.Fatalf("starting pour error : %s", err) - return err - } - return nil + + return runPour(inputEventChan, holders, buckets, cConfig) }) } + bucketWg.Done() + return nil }) bucketWg.Wait() + apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + if err != nil { + return err + } + + log.Debugf("Starting HeartBeat service") + apiClient.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) + outputWg := &sync.WaitGroup{} + outputsTomb.Go(func() error { outputWg.Add(1) - for i := 0; i < cConfig.Crowdsec.OutputRoutinesCount; i++ { + + for range cConfig.Crowdsec.OutputRoutinesCount { outputsTomb.Go(func() error { defer trace.CatchPanic("crowdsec/runOutput") - if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, *cConfig.API.Client.Credentials, hub); err != nil { - log.Fatalf("starting outputs error : %s", err) - return err - } - return nil + + return runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, apiClient) }) } + outputWg.Done() + return nil }) outputWg.Wait() + mp := NewMetricsProvider( + apiClient, + lpMetricsDefaultInterval, + log.WithField("service", "lpmetrics"), + []string{}, + datasources, + hub, + ) + + lpMetricsTomb.Go(func() error { + return mp.Run(context.Background(), &lpMetricsTomb) + }) + if cConfig.Prometheus != nil && cConfig.Prometheus.Enabled { aggregated := false - if cConfig.Prometheus.Level == "aggregated" { + if cConfig.Prometheus.Level == configuration.CFG_METRICS_AGGREGATE { aggregated = true } + if err := acquisition.GetMetrics(dataSources, aggregated); err != nil { return fmt.Errorf("while fetching prometheus metrics for datasources: %w", err) } - } + log.Info("Starting processing data") - if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { - log.Fatalf("starting acquisition error : %s", err) - return err + if err := acquisition.StartAcquisition(context.TODO(), dataSources, inputLineChan, &acquisTomb); err != nil { + return fmt.Errorf("starting acquisition error: %w", err) } return nil } -func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub.Hub, agentReady chan bool) { +// serveCrowdsec wraps the log processor service +func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub.Hub, datasources []acquisition.DataSource, agentReady chan bool) { crowdsecTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveCrowdsec") + go func() { defer trace.CatchPanic("crowdsec/runCrowdsec") // this logs every time, even at config reload log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) agentReady <- true - if err := runCrowdsec(cConfig, parsers, hub); err != nil { + + if err := runCrowdsec(cConfig, parsers, hub, datasources); err != nil { log.Fatalf("unable to start crowdsec routines: %s", err) } }() - /*we should stop in two cases : + /* we should stop in two cases : - crowdsecTomb has been Killed() : it might be shutdown or reload, so stop - acquisTomb is dead, it means that we were in "cat" mode and files are done reading, quit */ waitOnTomb() log.Debugf("Shutting down crowdsec routines") + if err := ShutdownCrowdsecRoutines(); err != nil { - log.Fatalf("unable to shutdown crowdsec routines: %s", err) + return fmt.Errorf("unable to shutdown crowdsec routines: %w", err) } + log.Debugf("everything is dead, return crowdsecTomb") + if dumpStates { - dumpParserState() - dumpOverflowState() - dumpBucketsPour() + if err := dumpAllStates(); err != nil { + log.Fatal(err) + } os.Exit(0) } + return nil }) } -func dumpBucketsPour() { - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "bucketpour-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(leaky.BucketPourCache) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - -func dumpParserState() { - - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "parser-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(parser.StageParseCache) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - -func dumpOverflowState() { - - fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - log.Fatalf("open: %s", err) - } - out, err := yaml.Marshal(bucketOverflows) - if err != nil { - log.Fatalf("marshal: %s", err) - } - b, err := fd.Write(out) - if err != nil { - log.Fatalf("write: %s", err) - } - log.Tracef("wrote %d bytes", b) - if err := fd.Close(); err != nil { - log.Fatalf(" close: %s", err) - } -} - func waitOnTomb() { for { select { case <-acquisTomb.Dead(): - /*if it's acquisition dying it means that we were in "cat" mode. + /* if it's acquisition dying it means that we were in "cat" mode. while shutting down, we need to give time for all buckets to process in flight data*/ - log.Warning("Acquisition is finished, shutting down") + log.Info("Acquisition is finished, shutting down") /* While it might make sense to want to shut-down parser/buckets/etc. as soon as acquisition is finished, we might have some pending buckets: buckets that overflowed, but whose LeakRoutine are still alive because they diff --git a/cmd/crowdsec/dump.go b/cmd/crowdsec/dump.go new file mode 100644 index 00000000000..33c65878b11 --- /dev/null +++ b/cmd/crowdsec/dump.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" + "github.com/crowdsecurity/crowdsec/pkg/parser" +) + +func dumpAllStates() error { + log.Debugf("Dumping parser+bucket states to %s", parser.DumpFolder) + + if err := dumpState( + filepath.Join(parser.DumpFolder, "parser-dump.yaml"), + parser.StageParseCache, + ); err != nil { + return fmt.Errorf("while dumping parser state: %w", err) + } + + if err := dumpState( + filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), + bucketOverflows, + ); err != nil { + return fmt.Errorf("while dumping bucket overflow state: %w", err) + } + + if err := dumpState( + filepath.Join(parser.DumpFolder, "bucketpour-dump.yaml"), + leaky.BucketPourCache, + ); err != nil { + return fmt.Errorf("while dumping bucket pour state: %w", err) + } + + return nil +} + +func dumpState(destPath string, obj any) error { + dir := filepath.Dir(destPath) + + err := os.MkdirAll(dir, 0o755) + if err != nil { + return err + } + + out, err := yaml.Marshal(obj) + if err != nil { + return err + } + + return os.WriteFile(destPath, out, 0o666) +} diff --git a/cmd/crowdsec/fatalhook.go b/cmd/crowdsec/fatalhook.go new file mode 100644 index 00000000000..84a57406a21 --- /dev/null +++ b/cmd/crowdsec/fatalhook.go @@ -0,0 +1,28 @@ +package main + +import ( + "io" + + log "github.com/sirupsen/logrus" +) + +// FatalHook is used to log fatal messages to stderr when the rest goes to a file +type FatalHook struct { + Writer io.Writer + LogLevels []log.Level +} + +func (hook *FatalHook) Fire(entry *log.Entry) error { + line, err := entry.String() + if err != nil { + return err + } + + _, err = hook.Writer.Write([]byte(line)) + + return err +} + +func (hook *FatalHook) Levels() []log.Level { + return hook.LogLevels +} diff --git a/cmd/crowdsec/hook.go b/cmd/crowdsec/hook.go deleted file mode 100644 index 28515d9e474..00000000000 --- a/cmd/crowdsec/hook.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "io" - "os" - - log "github.com/sirupsen/logrus" -) - -type ConditionalHook struct { - Writer io.Writer - LogLevels []log.Level - Enabled bool -} - -func (hook *ConditionalHook) Fire(entry *log.Entry) error { - if hook.Enabled { - line, err := entry.String() - if err != nil { - return err - } - - _, err = hook.Writer.Write([]byte(line)) - - return err - } - - return nil -} - -func (hook *ConditionalHook) Levels() []log.Level { - return hook.LogLevels -} - -// The primal logging hook is set up before parsing config.yaml. -// Once config.yaml is parsed, the primal hook is disabled if the -// configured logger is writing to stderr. Otherwise it's used to -// report fatal errors and panics to stderr in addition to the log file. -var primalHook = &ConditionalHook{ - Writer: os.Stderr, - LogLevels: []log.Level{log.FatalLevel, log.PanicLevel}, - Enabled: true, -} diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go new file mode 100644 index 00000000000..eed517f9df9 --- /dev/null +++ b/cmd/crowdsec/lapiclient.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/go-openapi/strfmt" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { + apiURL, err := url.Parse(credentials.URL) + if err != nil { + return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) + } + + papiURL, err := url.Parse(credentials.PapiURL) + if err != nil { + return nil, fmt.Errorf("parsing polling api url ('%s'): %w", credentials.PapiURL, err) + } + + password := strfmt.Password(credentials.Password) + + itemsForAPI := hub.GetInstalledListForAPI() + + client, err := apiclient.NewClient(&apiclient.Config{ + MachineID: credentials.Login, + Password: password, + Scenarios: itemsForAPI, + URL: apiURL, + PapiURL: papiURL, + VersionPrefix: "v1", + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil + }, + }) + if err != nil { + return nil, fmt.Errorf("new client api: %w", err) + } + + authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + MachineID: &credentials.Login, + Password: &password, + Scenarios: itemsForAPI, + }) + if err != nil { + return nil, fmt.Errorf("authenticate watcher (%s): %w", credentials.Login, err) + } + + var expiration time.Time + if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return nil, fmt.Errorf("unable to parse jwt expiration: %w", err) + } + + client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + client.GetClient().Transport.(*apiclient.JWTTransport).Expiration = expiration + + return client, nil +} diff --git a/cmd/crowdsec/lpmetrics.go b/cmd/crowdsec/lpmetrics.go new file mode 100644 index 00000000000..24842851294 --- /dev/null +++ b/cmd/crowdsec/lpmetrics.go @@ -0,0 +1,180 @@ +package main + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/fflag" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +const lpMetricsDefaultInterval = 30 * time.Minute + +// MetricsProvider collects metrics from the LP and sends them to the LAPI +type MetricsProvider struct { + apic *apiclient.ApiClient + interval time.Duration + static staticMetrics + logger *logrus.Entry +} + +type staticMetrics struct { + osName string + osVersion string + startupTS int64 + featureFlags []string + consoleOptions []string + datasourceMap map[string]int64 + hubState models.HubItems +} + +func getHubState(hub *cwhub.Hub) models.HubItems { + ret := models.HubItems{} + + for _, itemType := range cwhub.ItemTypes { + ret[itemType] = []models.HubItem{} + + for _, item := range hub.GetInstalledByType(itemType, true) { + status := "official" + if item.State.IsLocal() { + status = "custom" + } + if item.State.Tainted { + status = "tainted" + } + ret[itemType] = append(ret[itemType], models.HubItem{ + Name: item.Name, + Status: status, + Version: item.Version, + }) + } + } + + return ret +} + +// newStaticMetrics is called when the process starts, or reloads the configuration +func newStaticMetrics(consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub) staticMetrics { + datasourceMap := map[string]int64{} + + for _, ds := range datasources { + datasourceMap[ds.GetName()] += 1 + } + + osName, osVersion := version.DetectOS() + + return staticMetrics{ + osName: osName, + osVersion: osVersion, + startupTS: time.Now().UTC().Unix(), + featureFlags: fflag.Crowdsec.GetEnabledFeatures(), + consoleOptions: consoleOptions, + datasourceMap: datasourceMap, + hubState: getHubState(hub), + } +} + +func NewMetricsProvider(apic *apiclient.ApiClient, interval time.Duration, logger *logrus.Entry, + consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub, +) *MetricsProvider { + return &MetricsProvider{ + apic: apic, + interval: interval, + logger: logger, + static: newStaticMetrics(consoleOptions, datasources, hub), + } +} + +func (m *MetricsProvider) metricsPayload() *models.AllMetrics { + os := &models.OSversion{ + Name: ptr.Of(m.static.osName), + Version: ptr.Of(m.static.osVersion), + } + + base := models.BaseMetrics{ + UtcStartupTimestamp: ptr.Of(m.static.startupTS), + Os: os, + Version: ptr.Of(version.String()), + FeatureFlags: m.static.featureFlags, + Metrics: make([]*models.DetailedMetrics, 0), + } + + met := &models.LogProcessorsMetrics{ + BaseMetrics: base, + Datasources: m.static.datasourceMap, + HubItems: m.static.hubState, + } + + met.Metrics = append(met.Metrics, &models.DetailedMetrics{ + Meta: &models.MetricsMeta{ + UtcNowTimestamp: ptr.Of(time.Now().Unix()), + WindowSizeSeconds: ptr.Of(int64(m.interval.Seconds())), + }, + Items: make([]*models.MetricsDetailItem, 0), + }) + + return &models.AllMetrics{ + LogProcessors: []*models.LogProcessorsMetrics{met}, + } +} + +func (m *MetricsProvider) Run(ctx context.Context, myTomb *tomb.Tomb) error { + defer trace.CatchPanic("crowdsec/MetricsProvider.Run") + + if m.interval == time.Duration(0) { + return nil + } + + met := m.metricsPayload() + + ticker := time.NewTicker(1) // Send on start + + for { + select { + case <-ticker.C: + ctxTime, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + _, resp, err := m.apic.UsageMetrics.Add(ctxTime, met) + switch { + case errors.Is(err, context.DeadlineExceeded): + m.logger.Warnf("timeout sending lp metrics") + ticker.Reset(m.interval) + continue + case err != nil && resp != nil && resp.Response.StatusCode == http.StatusNotFound: + m.logger.Warnf("metrics endpoint not found, older LAPI?") + ticker.Reset(m.interval) + continue + case err != nil: + m.logger.Warnf("failed to send lp metrics: %s", err) + ticker.Reset(m.interval) + continue + } + + if resp.Response.StatusCode != http.StatusCreated { + m.logger.Warnf("failed to send lp metrics: %s", resp.Response.Status) + ticker.Reset(m.interval) + continue + } + + ticker.Reset(m.interval) + + m.logger.Tracef("lp usage metrics sent") + case <-myTomb.Dying(): + ticker.Stop() + return nil + } + } +} diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 2040141bb3e..6d8ca24c335 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -1,19 +1,22 @@ package main import ( + "errors" "flag" "fmt" _ "net/http/pprof" "os" + "path/filepath" "runtime" "runtime/pprof" "strings" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/crowdsec/pkg/acquisition" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -26,28 +29,29 @@ import ( ) var ( - /*tombs for the parser, buckets and outputs.*/ - acquisTomb tomb.Tomb - parsersTomb tomb.Tomb - bucketsTomb tomb.Tomb - outputsTomb tomb.Tomb - apiTomb tomb.Tomb - crowdsecTomb tomb.Tomb - pluginTomb tomb.Tomb + // tombs for the parser, buckets and outputs. + acquisTomb tomb.Tomb + parsersTomb tomb.Tomb + bucketsTomb tomb.Tomb + outputsTomb tomb.Tomb + apiTomb tomb.Tomb + crowdsecTomb tomb.Tomb + pluginTomb tomb.Tomb + lpMetricsTomb tomb.Tomb flags *Flags - /*the state of acquisition*/ + // the state of acquisition dataSources []acquisition.DataSource - /*the state of the buckets*/ + // the state of the buckets holders []leakybucket.BucketFactory buckets *leakybucket.Buckets inputLineChan chan types.Event inputEventChan chan types.Event outputEventChan chan types.Event // the buckets init returns its own chan that is used for multiplexing - /*settings*/ - lastProcessedItem time.Time /*keep track of last item timestamp in time-machine. it is used to GC buckets when we dump them.*/ + // settings + lastProcessedItem time.Time // keep track of last item timestamp in time-machine. it is used to GC buckets when we dump them. pluginBroker csplugin.PluginBroker ) @@ -72,7 +76,11 @@ type Flags struct { DisableCAPI bool Transform string OrderEvent bool - CpuProfile string + CPUProfile string +} + +func (f *Flags) haveTimeMachine() bool { + return f.OneShotDSN != "" } type labelsMap map[string]string @@ -83,19 +91,17 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { files []string ) - for _, hubScenarioItem := range hub.GetItemMap(cwhub.SCENARIOS) { - if hubScenarioItem.State.Installed { - files = append(files, hubScenarioItem.State.LocalPath) - } + for _, hubScenarioItem := range hub.GetInstalledByType(cwhub.SCENARIOS, false) { + files = append(files, hubScenarioItem.State.LocalPath) } buckets = leakybucket.NewBuckets() log.Infof("Loading %d scenario files", len(files)) - holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, files, &bucketsTomb, buckets, flags.OrderEvent) + holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, files, &bucketsTomb, buckets, flags.OrderEvent) if err != nil { - return fmt.Errorf("scenario loading failed: %v", err) + return fmt.Errorf("scenario loading failed: %w", err) } if cConfig.Prometheus != nil && cConfig.Prometheus.Enabled { @@ -107,7 +113,7 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { return nil } -func LoadAcquisition(cConfig *csconfig.Config) error { +func LoadAcquisition(cConfig *csconfig.Config) ([]acquisition.DataSource, error) { var err error if flags.SingleFileType != "" && flags.OneShotDSN != "" { @@ -116,20 +122,20 @@ func LoadAcquisition(cConfig *csconfig.Config) error { dataSources, err = acquisition.LoadAcquisitionFromDSN(flags.OneShotDSN, flags.Labels, flags.Transform) if err != nil { - return errors.Wrapf(err, "failed to configure datasource for %s", flags.OneShotDSN) + return nil, fmt.Errorf("failed to configure datasource for %s: %w", flags.OneShotDSN, err) } } else { - dataSources, err = acquisition.LoadAcquisitionFromFile(cConfig.Crowdsec) + dataSources, err = acquisition.LoadAcquisitionFromFile(cConfig.Crowdsec, cConfig.Prometheus) if err != nil { - return err + return nil, err } } if len(dataSources) == 0 { - return fmt.Errorf("no datasource enabled") + return nil, errors.New("no datasource enabled") } - return nil + return dataSources, nil } var ( @@ -181,7 +187,7 @@ func (f *Flags) Parse() { } flag.StringVar(&dumpFolder, "dump-data", "", "dump parsers/buckets raw outputs") - flag.StringVar(&f.CpuProfile, "cpu-profile", "", "write cpu profile to file") + flag.StringVar(&f.CPUProfile, "cpu-profile", "", "write cpu profile to file") flag.Parse() } @@ -226,6 +232,10 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo return nil, fmt.Errorf("while loading configuration file: %w", err) } + if err := trace.Init(filepath.Join(cConfig.ConfigPaths.DataDir, "trace")); err != nil { + return nil, fmt.Errorf("while setting up trace directory: %w", err) + } + cConfig.Common.LogLevel = newLogLevel(cConfig.Common.LogLevel, flags) if dumpFolder != "" { @@ -249,7 +259,12 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo return nil, err } - primalHook.Enabled = (cConfig.Common.LogMedia != "stdout") + if cConfig.Common.LogMedia != "stdout" { + log.AddHook(&FatalHook{ + Writer: os.Stderr, + LogLevels: []log.Level{log.FatalLevel, log.PanicLevel}, + }) + } if err := csconfig.LoadFeatureFlagsFile(configFile, log.StandardLogger()); err != nil { return nil, err @@ -272,7 +287,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo } if cConfig.DisableAPI && cConfig.DisableAgent { - return nil, errors.New("You must run at least the API Server or crowdsec") + return nil, errors.New("you must run at least the API Server or crowdsec") } if flags.OneShotDSN != "" && flags.SingleFileType == "" { @@ -291,7 +306,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo if cConfig.API != nil && cConfig.API.Server != nil { cConfig.API.Server.OnlineClient = nil } - /*if the api is disabled as well, just read file and exit, don't daemonize*/ + // if the api is disabled as well, just read file and exit, don't daemonize if cConfig.DisableAPI { cConfig.Common.Daemonize = false } @@ -323,7 +338,9 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo var crowdsecT0 time.Time func main() { - log.AddHook(primalHook) + // The initial log level is INFO, even if the user provided an -error or -warning flag + // because we need feature flags before parsing cli flags + log.SetFormatter(&log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true}) if err := fflag.RegisterAllFeatures(); err != nil { log.Fatalf("failed to register features: %s", err) @@ -351,20 +368,23 @@ func main() { } if flags.PrintVersion { - cwversion.Show() + os.Stdout.WriteString(cwversion.FullString()) os.Exit(0) } - if flags.CpuProfile != "" { - f, err := os.Create(flags.CpuProfile) + if flags.CPUProfile != "" { + f, err := os.Create(flags.CPUProfile) if err != nil { log.Fatalf("could not create CPU profile: %s", err) } - log.Infof("CPU profile will be written to %s", flags.CpuProfile) + + log.Infof("CPU profile will be written to %s", flags.CPUProfile) + if err := pprof.StartCPUProfile(f); err != nil { f.Close() log.Fatalf("could not start CPU profile: %s", err) } + defer f.Close() defer pprof.StopCPUProfile() } diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index ca893872edb..ff280fc3512 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -3,7 +3,6 @@ package main import ( "fmt" "net/http" - "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -12,7 +11,8 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/version" - v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" "github.com/crowdsecurity/crowdsec/pkg/cache" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -21,7 +21,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/parser" ) -/*prometheus*/ +// Prometheus + var globalParserHits = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_total", @@ -29,6 +30,7 @@ var globalParserHits = prometheus.NewCounterVec( }, []string{"source", "type"}, ) + var globalParserHitsOk = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_ok_total", @@ -36,6 +38,7 @@ var globalParserHitsOk = prometheus.NewCounterVec( }, []string{"source", "type"}, ) + var globalParserHitsKo = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_parser_hits_ko_total", @@ -102,25 +105,31 @@ var globalPourHistogram = prometheus.NewHistogramVec( func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - //update cache metrics (stash) + // catch panics here because they are not handled by servePrometheus + defer trace.CatchPanic("crowdsec/computeDynamicMetrics") + // update cache metrics (stash) cache.UpdateCacheMetrics() - //update cache metrics (regexp) + // update cache metrics (regexp) exprhelpers.UpdateRegexpCacheMetrics() - //decision metrics are only relevant for LAPI + // decision metrics are only relevant for LAPI if dbClient == nil { next.ServeHTTP(w, r) return } - decisionsFilters := make(map[string][]string, 0) - decisions, err := dbClient.QueryDecisionCountByScenario(decisionsFilters) + ctx := r.Context() + + decisions, err := dbClient.QueryDecisionCountByScenario(ctx) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) + return } + globalActiveDecisions.Reset() + for _, d := range decisions { globalActiveDecisions.With(prometheus.Labels{"reason": d.Scenario, "origin": d.Origin, "action": d.Type}).Set(float64(d.Count)) } @@ -131,11 +140,11 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha "include_capi": {"false"}, } - alerts, err := dbClient.AlertsCountPerScenario(alertsFilter) - + alerts, err := dbClient.AlertsCountPerScenario(ctx, alertsFilter) if err != nil { log.Errorf("Error querying alerts for metrics: %v", err) next.ServeHTTP(w, r) + return } @@ -154,14 +163,14 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { // Registering prometheus // If in aggregated mode, do not register events associated with a source, to keep the cardinality low - if config.Level == "aggregated" { + if config.Level == configuration.CFG_METRICS_AGGREGATE { log.Infof("Loading aggregated prometheus collectors") prometheus.MustRegister(globalParserHits, globalParserHitsOk, globalParserHitsKo, globalCsInfo, globalParsingHistogram, globalPourHistogram, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, v1.LapiRouteHits, leaky.BucketsCurrentCount, - cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, + cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, parser.NodesWlHitsOk, parser.NodesWlHits, ) } else { log.Infof("Loading prometheus collectors") @@ -170,14 +179,15 @@ func registerPrometheus(config *csconfig.PrometheusCfg) { globalCsInfo, globalParsingHistogram, globalPourHistogram, v1.LapiRouteHits, v1.LapiMachineHits, v1.LapiBouncerHits, v1.LapiNilDecisions, v1.LapiNonNilDecisions, v1.LapiResponseTime, leaky.BucketsPour, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, leaky.BucketsCurrentCount, - globalActiveDecisions, globalAlerts, + globalActiveDecisions, globalAlerts, parser.NodesWlHitsOk, parser.NodesWlHits, cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, ) - } } -func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, apiReady chan bool, agentReady chan bool) { +func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, agentReady chan bool) { + <-agentReady + if !config.Enabled { return } @@ -185,10 +195,11 @@ func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, defer trace.CatchPanic("crowdsec/servePrometheus") http.Handle("/metrics", computeDynamicMetrics(promhttp.Handler(), dbClient)) - <-apiReady - <-agentReady - log.Debugf("serving metrics after %s ms", time.Since(crowdsecT0)) + if err := http.ListenAndServe(fmt.Sprintf("%s:%d", config.ListenAddr, config.ListenPort), nil); err != nil { - log.Warningf("prometheus: %s", err) + // in time machine, we most likely have the LAPI using the port + if !flags.haveTimeMachine() { + log.Warningf("prometheus: %s", err) + } } } diff --git a/cmd/crowdsec/output.go b/cmd/crowdsec/output.go index ad53ce4c827..6f507fdcd6f 100644 --- a/cmd/crowdsec/output.go +++ b/cmd/crowdsec/output.go @@ -3,18 +3,12 @@ package main import ( "context" "fmt" - "net/url" "sync" "time" - "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/version" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/parser" @@ -22,7 +16,6 @@ import ( ) func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { - var dedupCache []*models.Alert for idx, alert := range alerts { @@ -32,113 +25,51 @@ func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) { dedupCache = append(dedupCache, alert.Alert) continue } - for k, src := range alert.Sources { - refsrc := *alert.Alert //copy + + for k := range alert.Sources { + refsrc := *alert.Alert // copy + log.Tracef("source[%s]", k) + + src := alert.Sources[k] refsrc.Source = &src dedupCache = append(dedupCache, &refsrc) } } + if len(dedupCache) != len(alerts) { log.Tracef("went from %d to %d alerts", len(alerts), len(dedupCache)) } + return dedupCache, nil } func PushAlerts(alerts []types.RuntimeAlert, client *apiclient.ApiClient) error { ctx := context.Background() - alertsToPush, err := dedupAlerts(alerts) + alertsToPush, err := dedupAlerts(alerts) if err != nil { return fmt.Errorf("failed to transform alerts for api: %w", err) } + _, _, err = client.Alerts.Add(ctx, alertsToPush) if err != nil { return fmt.Errorf("failed sending alert to LAPI: %w", err) } + return nil } var bucketOverflows []types.Event -func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, - postOverflowCTX parser.UnixParserCtx, postOverflowNodes []parser.Node, - apiConfig csconfig.ApiCredentialsCfg, hub *cwhub.Hub) error { +func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, postOverflowCTX parser.UnixParserCtx, + postOverflowNodes []parser.Node, client *apiclient.ApiClient) error { + var ( + cache []types.RuntimeAlert + cacheMutex sync.Mutex + ) - var err error ticker := time.NewTicker(1 * time.Second) - - var cache []types.RuntimeAlert - var cacheMutex sync.Mutex - - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("loading list of installed hub scenarios: %w", err) - } - - appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) - if err != nil { - return fmt.Errorf("loading list of installed hub appsec rules: %w", err) - } - - installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules)) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...) - - apiURL, err := url.Parse(apiConfig.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err) - } - papiURL, err := url.Parse(apiConfig.PapiURL) - if err != nil { - return fmt.Errorf("parsing polling api url ('%s'): %w", apiConfig.PapiURL, err) - } - password := strfmt.Password(apiConfig.Password) - - Client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: apiConfig.Login, - Password: password, - Scenarios: installedScenariosAndAppsecRules, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - PapiURL: papiURL, - VersionPrefix: "v1", - UpdateScenario: func() ([]string, error) { - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return nil, err - } - appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) - if err != nil { - return nil, err - } - ret := make([]string, 0, len(scenarios)+len(appsecRules)) - ret = append(ret, scenarios...) - ret = append(ret, appsecRules...) - return ret, nil - }, - }) - if err != nil { - return fmt.Errorf("new client api: %w", err) - } - authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - MachineID: &apiConfig.Login, - Password: &password, - Scenarios: installedScenariosAndAppsecRules, - }) - if err != nil { - return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err) - } - - if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { - return fmt.Errorf("unable to parse jwt expiration: %w", err) - } - - Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token - - //start the heartbeat service - log.Debugf("Starting HeartBeat service") - Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) LOOP: for { select { @@ -149,9 +80,9 @@ LOOP: newcache := make([]types.RuntimeAlert, 0) cache = newcache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing to api : %s", err) - //just push back the events to the queue + // just push back the events to the queue cacheMutex.Lock() cache = append(cache, cachecopy...) cacheMutex.Unlock() @@ -162,10 +93,11 @@ LOOP: cacheMutex.Lock() cachecopy := cache cacheMutex.Unlock() - if err := PushAlerts(cachecopy, Client); err != nil { + if err := PushAlerts(cachecopy, client); err != nil { log.Errorf("while pushing leftovers to api : %s", err) } } + break LOOP case event := <-overflow: /*if alert is empty and mapKey is present, the overflow is just to cleanup bucket*/ @@ -176,11 +108,11 @@ LOOP: /* process post overflow parser nodes */ event, err := parser.Parse(postOverflowCTX, event, postOverflowNodes) if err != nil { - return fmt.Errorf("postoverflow failed : %s", err) + return fmt.Errorf("postoverflow failed: %w", err) } log.Printf("%s", *event.Overflow.Alert.Message) - //if the Alert is nil, it's to signal bucket is ready for GC, don't track this - //dump after postoveflow processing to avoid missing whitelist info + // if the Alert is nil, it's to signal bucket is ready for GC, don't track this + // dump after postoveflow processing to avoid missing whitelist info if dumpStates && event.Overflow.Alert != nil { if bucketOverflows == nil { bucketOverflows = make([]types.Event, 0) @@ -206,6 +138,6 @@ LOOP: } ticker.Stop() - return nil + return nil } diff --git a/cmd/crowdsec/parse.go b/cmd/crowdsec/parse.go index c62eeb5869d..26eae66be2b 100644 --- a/cmd/crowdsec/parse.go +++ b/cmd/crowdsec/parse.go @@ -11,13 +11,11 @@ import ( ) func runParse(input chan types.Event, output chan types.Event, parserCTX parser.UnixParserCtx, nodes []parser.Node) error { - -LOOP: for { select { case <-parsersTomb.Dying(): log.Infof("Killing parser routines") - break LOOP + return nil case event := <-input: if !event.Process { continue @@ -39,7 +37,7 @@ LOOP: /* parse the log using magic */ parsed, err := parser.Parse(parserCTX, event, nodes) if err != nil { - log.Errorf("failed parsing : %v\n", err) + log.Errorf("failed parsing: %v", err) } elapsed := time.Since(startParsing) globalParsingHistogram.With(prometheus.Labels{"source": event.Line.Src, "type": event.Line.Module}).Observe(elapsed.Seconds()) @@ -56,5 +54,4 @@ LOOP: output <- parsed } } - return nil } diff --git a/cmd/crowdsec/pour.go b/cmd/crowdsec/pour.go index 3f717e3975d..2fc7d7e42c9 100644 --- a/cmd/crowdsec/pour.go +++ b/cmd/crowdsec/pour.go @@ -4,57 +4,64 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" ) func runPour(input chan types.Event, holders []leaky.BucketFactory, buckets *leaky.Buckets, cConfig *csconfig.Config) error { count := 0 + for { - //bucket is now ready + // bucket is now ready select { case <-bucketsTomb.Dying(): log.Infof("Bucket routine exiting") return nil case parsed := <-input: startTime := time.Now() + count++ if count%5000 == 0 { log.Infof("%d existing buckets", leaky.LeakyRoutineCount) - //when in forensics mode, garbage collect buckets + // when in forensics mode, garbage collect buckets if cConfig.Crowdsec.BucketsGCEnabled { if parsed.MarshaledTime != "" { z := &time.Time{} if err := z.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("Failed to unmarshal time from event '%s' : %s", parsed.MarshaledTime, err) + log.Warningf("Failed to parse time from event '%s' : %s", parsed.MarshaledTime, err) } else { log.Warning("Starting buckets garbage collection ...") + if err = leaky.GarbageCollectBuckets(*z, buckets); err != nil { - return fmt.Errorf("failed to start bucket GC : %s", err) + return fmt.Errorf("failed to start bucket GC : %w", err) } } } } } - //here we can bucketify with parsed + // here we can bucketify with parsed poured, err := leaky.PourItemToHolders(parsed, holders, buckets) if err != nil { log.Errorf("bucketify failed for: %v", parsed) continue } + elapsed := time.Since(startTime) globalPourHistogram.With(prometheus.Labels{"type": parsed.Line.Module, "source": parsed.Line.Src}).Observe(elapsed.Seconds()) + if poured { globalBucketPourOk.Inc() } else { globalBucketPourKo.Inc() } - if len(parsed.MarshaledTime) != 0 { + + if parsed.MarshaledTime != "" { if err := lastProcessedItem.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("failed to unmarshal time from event : %s", err) + log.Warningf("failed to parse time from event : %s", err) } } } diff --git a/cmd/crowdsec/run_in_svc.go b/cmd/crowdsec/run_in_svc.go index 2020537908d..288b565e890 100644 --- a/cmd/crowdsec/run_in_svc.go +++ b/cmd/crowdsec/run_in_svc.go @@ -3,6 +3,7 @@ package main import ( + "context" "fmt" "runtime/pprof" @@ -23,8 +24,8 @@ func StartRunSvc() error { defer trace.CatchPanic("crowdsec/StartRunSvc") - //Always try to stop CPU profiling to avoid passing flags around - //It's a noop if profiling is not enabled + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled defer pprof.StopCPUProfile() if cConfig, err = LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false); err != nil { @@ -33,7 +34,6 @@ func StartRunSvc() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early @@ -42,18 +42,24 @@ func StartRunSvc() error { var err error - if cConfig.DbConfig != nil { - dbClient, err = database.NewClient(cConfig.DbConfig) + ctx := context.TODO() + if cConfig.DbConfig != nil { + dbClient, err = database.NewClient(ctx, cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return fmt.Errorf("unable to create database client: %w", err) } } registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) + } else { + // avoid leaking the channel + go func() { + <-agentReady + }() } - return Serve(cConfig, apiReady, agentReady) + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/run_in_svc_windows.go b/cmd/crowdsec/run_in_svc_windows.go index 991f7ae4491..a2a2dd8c47a 100644 --- a/cmd/crowdsec/run_in_svc_windows.go +++ b/cmd/crowdsec/run_in_svc_windows.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "runtime/pprof" @@ -20,8 +21,8 @@ func StartRunSvc() error { defer trace.CatchPanic("crowdsec/StartRunSvc") - //Always try to stop CPU profiling to avoid passing flags around - //It's a noop if profiling is not enabled + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled defer pprof.StopCPUProfile() isRunninginService, err := svc.IsWindowsService() @@ -73,7 +74,6 @@ func WindowsRun() error { log.Infof("Crowdsec %s", version.String()) - apiReady := make(chan bool, 1) agentReady := make(chan bool, 1) // Enable profiling early @@ -81,15 +81,17 @@ func WindowsRun() error { var dbClient *database.Client var err error + ctx := context.TODO() + if cConfig.DbConfig != nil { - dbClient, err = database.NewClient(cConfig.DbConfig) + dbClient, err = database.NewClient(ctx, cConfig.DbConfig) if err != nil { - return fmt.Errorf("unable to create database client: %s", err) + return fmt.Errorf("unable to create database client: %w", err) } } registerPrometheus(cConfig.Prometheus) - go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) + go servePrometheus(cConfig.Prometheus, dbClient, agentReady) } - return Serve(cConfig, apiReady, agentReady) + return Serve(cConfig, agentReady) } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index a5c8e24cf3f..14602c425fe 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "os/signal" @@ -42,13 +43,17 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { if err := leaky.ShutdownAllBuckets(buckets); err != nil { log.Warningf("Failed to shut down routines : %s", err) } + log.Printf("Shutdown is finished, buckets are in %s", tmpFile) + return nil } func reloadHandler(sig os.Signal) (*csconfig.Config, error) { var tmpFile string + ctx := context.TODO() + // re-initialize tombs acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} @@ -57,6 +62,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { apiTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{} + lpMetricsTomb = tomb.Tomb{} cConfig, err := LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false) if err != nil { @@ -66,24 +72,29 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { if !cConfig.DisableAPI { if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } - apiReady := make(chan bool, 1) - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } if !cConfig.DisableAgent { - hub, err := cwhub.NewHub(cConfig.Hub, nil, false, log.StandardLogger()) + hub, err := cwhub.NewHub(cConfig.Hub, nil, log.StandardLogger()) if err != nil { - return nil, fmt.Errorf("while loading hub index: %w", err) + return nil, err + } + + if err = hub.Load(); err != nil { + return nil, err } - csParsers, err := initCrowdsec(cConfig, hub) + csParsers, datasources, err := initCrowdsec(cConfig, hub) if err != nil { return nil, fmt.Errorf("unable to init crowdsec: %w", err) } @@ -100,7 +111,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { } agentReady := make(chan bool, 1) - serveCrowdsec(csParsers, cConfig, hub, agentReady) + serveCrowdsec(csParsers, cConfig, hub, datasources, agentReady) } log.Printf("Reload is finished") @@ -110,6 +121,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { log.Warningf("Failed to delete temp file (%s) : %s", tmpFile, err) } } + return cConfig, nil } @@ -117,10 +129,12 @@ func ShutdownCrowdsecRoutines() error { var reterr error log.Debugf("Shutting down crowdsec sub-routines") + if len(dataSources) > 0 { acquisTomb.Kill(nil) log.Debugf("waiting for acquisition to finish") drainChan(inputLineChan) + if err := acquisTomb.Wait(); err != nil { log.Warningf("Acquisition returned error : %s", err) reterr = err @@ -130,6 +144,7 @@ func ShutdownCrowdsecRoutines() error { log.Debugf("acquisition is finished, wait for parser/bucket/ouputs.") parsersTomb.Kill(nil) drainChan(inputEventChan) + if err := parsersTomb.Wait(); err != nil { log.Warningf("Parsers returned error : %s", err) reterr = err @@ -160,15 +175,28 @@ func ShutdownCrowdsecRoutines() error { log.Warningf("Outputs returned error : %s", err) reterr = err } + log.Debugf("outputs are done") case <-time.After(3 * time.Second): // this can happen if outputs are stuck in a http retry loop log.Warningf("Outputs didn't finish in time, some events may have not been flushed") } + lpMetricsTomb.Kill(nil) + + if err := lpMetricsTomb.Wait(); err != nil { + log.Warningf("Metrics returned error : %s", err) + reterr = err + } + + log.Debugf("metrics are done") + // He's dead, Jim. crowdsecTomb.Kill(nil) + // close the potential geoips reader we have to avoid leaking ressources on reload + exprhelpers.GeoIPClose() + return reterr } @@ -181,6 +209,7 @@ func shutdownAPI() error { } log.Debugf("done") + return nil } @@ -193,6 +222,7 @@ func shutdownCrowdsec() error { } log.Debugf("done") + return nil } @@ -220,7 +250,7 @@ func drainChan(c chan types.Event) { for { select { case _, ok := <-c: - if !ok { //closed + if !ok { // closed return } default: @@ -246,8 +276,8 @@ func HandleSignals(cConfig *csconfig.Config) error { exitChan := make(chan error) - //Always try to stop CPU profiling to avoid passing flags around - //It's a noop if profiling is not enabled + // Always try to stop CPU profiling to avoid passing flags around + // It's a noop if profiling is not enabled defer pprof.StopCPUProfile() go func() { @@ -292,10 +322,11 @@ func HandleSignals(cConfig *csconfig.Config) error { if err == nil { log.Warning("Crowdsec service shutting down") } + return err } -func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) error { +func Serve(cConfig *csconfig.Config, agentReady chan bool) error { acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} bucketsTomb = tomb.Tomb{} @@ -303,9 +334,12 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e apiTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{} + lpMetricsTomb = tomb.Tomb{} + + ctx := context.TODO() if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { - dbClient, err := database.NewClient(cConfig.API.Server.DbConfig) + dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig) if err != nil { return fmt.Errorf("failed to get database client: %w", err) } @@ -323,8 +357,9 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e log.Warningln("Exprhelpers loaded without database client.") } - if cConfig.API.CTI != nil && *cConfig.API.CTI.Enabled { + if cConfig.API.CTI != nil && cConfig.API.CTI.Enabled != nil && *cConfig.API.CTI.Enabled { log.Infof("Crowdsec CTI helper enabled") + if err := exprhelpers.InitCrowdsecCTI(cConfig.API.CTI.Key, cConfig.API.CTI.CacheTimeout, cConfig.API.CTI.CacheSize, cConfig.API.CTI.LogLevel); err != nil { return fmt.Errorf("failed to init crowdsec cti: %w", err) } @@ -337,35 +372,40 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if flags.DisableCAPI { log.Warningf("Communication with CrowdSec Central API disabled from args") + cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return fmt.Errorf("api server init: %w", err) } if !flags.TestMode { - serveAPIServer(apiServer, apiReady) + serveAPIServer(apiServer) } - } else { - apiReady <- true } if !cConfig.DisableAgent { - hub, err := cwhub.NewHub(cConfig.Hub, nil, false, log.StandardLogger()) + hub, err := cwhub.NewHub(cConfig.Hub, nil, log.StandardLogger()) if err != nil { - return fmt.Errorf("while loading hub index: %w", err) + return err + } + + if err = hub.Load(); err != nil { + return err } - csParsers, err := initCrowdsec(cConfig, hub) + csParsers, datasources, err := initCrowdsec(cConfig, hub) if err != nil { return fmt.Errorf("crowdsec init: %w", err) } // if it's just linting, we're done if !flags.TestMode { - serveCrowdsec(csParsers, cConfig, hub, agentReady) + serveCrowdsec(csParsers, cConfig, hub, datasources, agentReady) + } else { + agentReady <- true } } else { agentReady <- true @@ -374,11 +414,12 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if flags.TestMode { log.Infof("Configuration test done") pluginBroker.Kill() - os.Exit(0) + + return nil } if cConfig.Common != nil && cConfig.Common.Daemonize { - csdaemon.NotifySystemd(log.StandardLogger()) + csdaemon.Notify(csdaemon.Ready, log.StandardLogger()) // wait for signals return HandleSignals(cConfig) } @@ -395,6 +436,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e for _, ch := range waitChans { <-ch + switch ch { case apiTomb.Dead(): log.Infof("api shutdown") @@ -402,5 +444,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e log.Infof("crowdsec shutdown") } } + return nil } diff --git a/cmd/notification-dummy/main.go b/cmd/notification-dummy/main.go index ef8d29ffa44..7fbb10d4fca 100644 --- a/cmd/notification-dummy/main.go +++ b/cmd/notification-dummy/main.go @@ -5,10 +5,12 @@ import ( "fmt" "os" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { @@ -18,6 +20,7 @@ type PluginConfig struct { } type DummyPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -32,6 +35,7 @@ func (s *DummyPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -42,19 +46,22 @@ func (s *DummyPlugin) Notify(ctx context.Context, notification *protobufs.Notifi logger.Debug(notification.Text) if cfg.OutputFile != nil && *cfg.OutputFile != "" { - f, err := os.OpenFile(*cfg.OutputFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(*cfg.OutputFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { logger.Error(fmt.Sprintf("Cannot open notification file: %s", err)) } + if _, err := f.WriteString(notification.Text + "\n"); err != nil { f.Close() logger.Error(fmt.Sprintf("Cannot write notification to file: %s", err)) } + err = f.Close() if err != nil { logger.Error(fmt.Sprintf("Cannot close notification file: %s", err)) } } + fmt.Println(notification.Text) return &protobufs.Empty{}, nil @@ -64,11 +71,12 @@ func (s *DummyPlugin) Configure(ctx context.Context, config *protobufs.Config) ( d := PluginConfig{} err := yaml.Unmarshal(config.Config, &d) s.PluginConfigByName[d.Name] = d + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -78,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "dummy": &protobufs.NotifierPlugin{ + "dummy": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-email/main.go b/cmd/notification-email/main.go index 789740156fe..5fc02cdd1d7 100644 --- a/cmd/notification-email/main.go +++ b/cmd/notification-email/main.go @@ -2,15 +2,18 @@ package main import ( "context" + "errors" "fmt" "os" "time" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" mail "github.com/xhit/go-simple-mail/v2" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) var baseLogger hclog.Logger = hclog.New(&hclog.LoggerOptions{ @@ -53,6 +56,7 @@ type PluginConfig struct { } type EmailPlugin struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -72,19 +76,20 @@ func (n *EmailPlugin) Configure(ctx context.Context, config *protobufs.Config) ( } if d.Name == "" { - return nil, fmt.Errorf("name is required") + return nil, errors.New("name is required") } if d.SMTPHost == "" { - return nil, fmt.Errorf("SMTP host is not set") + return nil, errors.New("SMTP host is not set") } - if d.ReceiverEmails == nil || len(d.ReceiverEmails) == 0 { - return nil, fmt.Errorf("receiver emails are not set") + if len(d.ReceiverEmails) == 0 { + return nil, errors.New("receiver emails are not set") } n.ConfigByName[d.Name] = d baseLogger.Debug(fmt.Sprintf("Email plugin '%s' use SMTP host '%s:%d'", d.Name, d.SMTPHost, d.SMTPPort)) + return &protobufs.Empty{}, nil } @@ -92,6 +97,7 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if _, ok := n.ConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := n.ConfigByName[notification.Name] logger := baseLogger.Named(cfg.Name) @@ -117,6 +123,7 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi server.ConnectTimeout, err = time.ParseDuration(cfg.ConnectTimeout) if err != nil { logger.Warn(fmt.Sprintf("invalid connect timeout '%s', using default '10s'", cfg.ConnectTimeout)) + server.ConnectTimeout = 10 * time.Second } } @@ -125,15 +132,18 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi server.SendTimeout, err = time.ParseDuration(cfg.SendTimeout) if err != nil { logger.Warn(fmt.Sprintf("invalid send timeout '%s', using default '10s'", cfg.SendTimeout)) + server.SendTimeout = 10 * time.Second } } logger.Debug("making smtp connection") + smtpClient, err := server.Connect() if err != nil { return &protobufs.Empty{}, err } + logger.Debug("smtp connection done") email := mail.NewMSG() @@ -146,12 +156,14 @@ func (n *EmailPlugin) Notify(ctx context.Context, notification *protobufs.Notifi if err != nil { return &protobufs.Empty{}, err } + logger.Info(fmt.Sprintf("sent email to %v", cfg.ReceiverEmails)) + return &protobufs.Empty{}, nil } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -160,7 +172,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "email": &protobufs.NotifierPlugin{ + "email": &csplugin.NotifierPlugin{ Impl: &EmailPlugin{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/cmd/notification-file/Makefile b/cmd/notification-file/Makefile new file mode 100644 index 00000000000..4504328c49a --- /dev/null +++ b/cmd/notification-file/Makefile @@ -0,0 +1,17 @@ +ifeq ($(OS), Windows_NT) + SHELL := pwsh.exe + .SHELLFLAGS := -NoProfile -Command + EXT = .exe +endif + +GO = go +GOBUILD = $(GO) build + +BINARY_NAME = notification-file$(EXT) + +build: clean + $(GOBUILD) $(LD_OPTS) -o $(BINARY_NAME) + +.PHONY: clean +clean: + @$(RM) $(BINARY_NAME) $(WIN_IGNORE_ERR) diff --git a/cmd/notification-file/file.yaml b/cmd/notification-file/file.yaml new file mode 100644 index 00000000000..61c77b9eb49 --- /dev/null +++ b/cmd/notification-file/file.yaml @@ -0,0 +1,23 @@ +# Don't change this +type: file + +name: file_default # this must match with the registered plugin in the profile +log_level: info # Options include: trace, debug, info, warn, error, off + +# This template render all events as ndjson +format: | + {{range . -}} + { "time": "{{.StopAt}}", "program": "crowdsec", "alert": {{. | toJson }} } + {{ end -}} + +# group_wait: # duration to wait collecting alerts before sending to this plugin, eg "30s" +# group_threshold: # if alerts exceed this, then the plugin will be sent the message. eg "10" + +#Use full path EG /tmp/crowdsec_alerts.json or %TEMP%\crowdsec_alerts.json +log_path: "/tmp/crowdsec_alerts.json" +rotate: + enabled: true # Change to false if you want to handle log rotate on system basis + max_size: 500 # in MB + max_files: 5 + max_age: 5 + compress: true diff --git a/cmd/notification-file/main.go b/cmd/notification-file/main.go new file mode 100644 index 00000000000..a4dbb8ee5db --- /dev/null +++ b/cmd/notification-file/main.go @@ -0,0 +1,253 @@ +package main + +import ( + "compress/gzip" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + plugin "github.com/hashicorp/go-plugin" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" +) + +var ( + FileWriter *os.File + FileWriteMutex *sync.Mutex + FileSize int64 +) + +type FileWriteCtx struct { + Ctx context.Context + Writer io.Writer +} + +func (w *FileWriteCtx) Write(p []byte) (n int, err error) { + if err := w.Ctx.Err(); err != nil { + return 0, err + } + return w.Writer.Write(p) +} + +type PluginConfig struct { + Name string `yaml:"name"` + LogLevel string `yaml:"log_level"` + LogPath string `yaml:"log_path"` + LogRotate LogRotate `yaml:"rotate"` +} + +type LogRotate struct { + MaxSize int `yaml:"max_size"` + MaxAge int `yaml:"max_age"` + MaxFiles int `yaml:"max_files"` + Enabled bool `yaml:"enabled"` + Compress bool `yaml:"compress"` +} + +type FilePlugin struct { + protobufs.UnimplementedNotifierServer + PluginConfigByName map[string]PluginConfig +} + +var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ + Name: "file-plugin", + Level: hclog.LevelFromString("INFO"), + Output: os.Stderr, + JSONFormat: true, +}) + +func (r *LogRotate) rotateLogs(cfg PluginConfig) { + // Rotate the log file + err := r.rotateLogFile(cfg.LogPath, r.MaxFiles) + if err != nil { + logger.Error("Failed to rotate log file", "error", err) + } + // Reopen the FileWriter + FileWriter.Close() + FileWriter, err = os.OpenFile(cfg.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + logger.Error("Failed to reopen log file", "error", err) + } + // Reset the file size + FileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + } + FileSize = FileInfo.Size() +} + +func (r *LogRotate) rotateLogFile(logPath string, maxBackups int) error { + // Rename the current log file + backupPath := logPath + "." + time.Now().Format("20060102-150405") + err := os.Rename(logPath, backupPath) + if err != nil { + return err + } + glob := logPath + ".*" + if r.Compress { + glob = logPath + ".*.gz" + err = compressFile(backupPath) + if err != nil { + return err + } + } + + // Remove old backups + files, err := filepath.Glob(glob) + if err != nil { + return err + } + + sort.Sort(sort.Reverse(sort.StringSlice(files))) + + for i, file := range files { + logger.Trace("Checking file", "file", file, "index", i, "maxBackups", maxBackups) + if i >= maxBackups { + logger.Trace("Removing file as over max backup count", "file", file) + os.Remove(file) + } else { + // Check the age of the file + fileInfo, err := os.Stat(file) + if err != nil { + return err + } + age := time.Since(fileInfo.ModTime()).Hours() + if age > float64(r.MaxAge*24) { + logger.Trace("Removing file as age was over configured amount", "file", file, "age", age) + os.Remove(file) + } + } + } + + return nil +} + +func compressFile(src string) error { + // Open the source file for reading + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + // Create the destination file + dstFile, err := os.Create(src + ".gz") + if err != nil { + return err + } + defer dstFile.Close() + + // Create a gzip writer + gw := gzip.NewWriter(dstFile) + defer gw.Close() + + // Read the source file and write its contents to the gzip writer + _, err = io.Copy(gw, srcFile) + if err != nil { + return err + } + + // Delete the original (uncompressed) backup file + err = os.Remove(src) + if err != nil { + return err + } + + return nil +} + +func WriteToFileWithCtx(ctx context.Context, cfg PluginConfig, log string) error { + FileWriteMutex.Lock() + defer FileWriteMutex.Unlock() + originalFileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + } + currentFileInfo, _ := os.Stat(cfg.LogPath) + if !os.SameFile(originalFileInfo, currentFileInfo) { + // The file has been rotated outside our control + logger.Info("Log file has been rotated or missing attempting to reopen it") + FileWriter.Close() + FileWriter, err = os.OpenFile(cfg.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return err + } + FileInfo, err := FileWriter.Stat() + if err != nil { + return err + } + FileSize = FileInfo.Size() + logger.Info("Log file has been reopened successfully") + } + n, err := io.WriteString(&FileWriteCtx{Ctx: ctx, Writer: FileWriter}, log) + if err == nil { + FileSize += int64(n) + if FileSize > int64(cfg.LogRotate.MaxSize)*1024*1024 && cfg.LogRotate.Enabled { + logger.Debug("Rotating log file", "file", cfg.LogPath) + // Rotate the log file + cfg.LogRotate.rotateLogs(cfg) + } + } + return err +} + +func (s *FilePlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { + if _, ok := s.PluginConfigByName[notification.Name]; !ok { + return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) + } + cfg := s.PluginConfigByName[notification.Name] + + return &protobufs.Empty{}, WriteToFileWithCtx(ctx, cfg, notification.Text) +} + +func (s *FilePlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { + d := PluginConfig{} + err := yaml.Unmarshal(config.Config, &d) + if err != nil { + logger.Error("Failed to parse config", "error", err) + return &protobufs.Empty{}, err + } + FileWriteMutex = &sync.Mutex{} + FileWriter, err = os.OpenFile(d.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + logger.Error("Failed to open log file", "error", err) + return &protobufs.Empty{}, err + } + FileInfo, err := FileWriter.Stat() + if err != nil { + logger.Error("Failed to get file info", "error", err) + return &protobufs.Empty{}, err + } + FileSize = FileInfo.Size() + s.PluginConfigByName[d.Name] = d + logger.SetLevel(hclog.LevelFromString(d.LogLevel)) + return &protobufs.Empty{}, err +} + +func main() { + handshake := plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "CROWDSEC_PLUGIN_KEY", + MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), + } + + sp := &FilePlugin{PluginConfigByName: make(map[string]PluginConfig)} + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshake, + Plugins: map[string]plugin.Plugin{ + "file": &csplugin.NotifierPlugin{ + Impl: sp, + }, + }, + GRPCServer: plugin.DefaultGRPCServer, + Logger: logger, + }) +} diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 340d462c175..3f84984315b 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -7,18 +7,23 @@ import ( "crypto/x509" "fmt" "io" + "net" "net/http" "os" + "strings" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { Name string `yaml:"name"` URL string `yaml:"url"` + UnixSocket string `yaml:"unix_socket"` Headers map[string]string `yaml:"headers"` SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` @@ -30,6 +35,7 @@ type PluginConfig struct { } type HTTPPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -66,42 +72,52 @@ func getCertPool(caPath string) (*x509.CertPool, error) { return cp, nil } -func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) { - var client *http.Client - - caCertPool, err := getCertPool(caPath) +func getTLSClient(c *PluginConfig) error { + caCertPool, err := getCertPool(c.CAPath) if err != nil { - return nil, err + return err } tlsConfig := &tls.Config{ RootCAs: caCertPool, - InsecureSkipVerify: tlsVerify, + InsecureSkipVerify: c.SkipTLSVerification, } - if certPath != "" && keyPath != "" { - logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath)) + if c.CertPath != "" && c.KeyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath)) - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) if err != nil { - return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err) + return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err) } tlsConfig.Certificates = []tls.Certificate{cert} } - client = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + if c.UnixSocket != "" { + logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket)) + + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/")) + } + } + + c.Client = &http.Client{ + Transport: transport, } - return client, err + + return nil } func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -114,11 +130,14 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific if err != nil { return nil, err } + for headerName, headerValue := range cfg.Headers { logger.Debug(fmt.Sprintf("adding header %s: %s", headerName, headerValue)) request.Header.Add(headerName, headerValue) } + logger.Debug(fmt.Sprintf("making HTTP %s call to %s with body %s", cfg.Method, cfg.URL, notification.Text)) + resp, err := cfg.Client.Do(request.WithContext(ctx)) if err != nil { logger.Error(fmt.Sprintf("Failed to make HTTP request : %s", err)) @@ -128,13 +147,15 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific respData, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response body got error %s", err) + return nil, fmt.Errorf("failed to read response body got error %w", err) } logger.Debug(fmt.Sprintf("got response %s", string(respData))) if resp.StatusCode < 200 || resp.StatusCode >= 300 { logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode)) + logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData))) + return &protobufs.Empty{}, nil } @@ -143,21 +164,25 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { d := PluginConfig{} + err := yaml.Unmarshal(config.Config, &d) if err != nil { return nil, err } - d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath) + + err = getTLSClient(&d) if err != nil { return nil, err } + s.PluginConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("HTTP plugin '%s' use URL '%s'", d.Name, d.URL)) + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -167,7 +192,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "http": &protobufs.NotifierPlugin{ + "http": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-sentinel/main.go b/cmd/notification-sentinel/main.go index c627f9271e2..0293d45b0a4 100644 --- a/cmd/notification-sentinel/main.go +++ b/cmd/notification-sentinel/main.go @@ -11,10 +11,12 @@ import ( "strings" "time" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { @@ -26,6 +28,7 @@ type PluginConfig struct { } type SentinelPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -37,7 +40,7 @@ var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ }) func (s *SentinelPlugin) getAuthorizationHeader(now string, length int, pluginName string) (string, error) { - xHeaders := "x-ms-date:" + now + xHeaders := "X-Ms-Date:" + now stringToHash := fmt.Sprintf("POST\n%d\napplication/json\n%s\n/api/logs", length, xHeaders) decodedKey, _ := base64.StdEncoding.DecodeString(s.PluginConfigByName[pluginName].SharedKey) @@ -54,7 +57,6 @@ func (s *SentinelPlugin) getAuthorizationHeader(now string, length int, pluginNa } func (s *SentinelPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { - if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } @@ -73,7 +75,6 @@ func (s *SentinelPlugin) Notify(ctx context.Context, notification *protobufs.Not now := time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT") authorization, err := s.getAuthorizationHeader(now, len(notification.Text), notification.Name) - if err != nil { return &protobufs.Empty{}, err } @@ -87,7 +88,7 @@ func (s *SentinelPlugin) Notify(ctx context.Context, notification *protobufs.Not req.Header.Set("Content-Type", "application/json") req.Header.Set("Log-Type", s.PluginConfigByName[notification.Name].LogType) req.Header.Set("Authorization", authorization) - req.Header.Set("x-ms-date", now) + req.Header.Set("X-Ms-Date", now) client := &http.Client{} resp, err := client.Do(req.WithContext(ctx)) @@ -113,7 +114,7 @@ func (s *SentinelPlugin) Configure(ctx context.Context, config *protobufs.Config } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -123,7 +124,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "sentinel": &protobufs.NotifierPlugin{ + "sentinel": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-slack/main.go b/cmd/notification-slack/main.go index 373cd9527ab..34c7c0df361 100644 --- a/cmd/notification-slack/main.go +++ b/cmd/notification-slack/main.go @@ -5,20 +5,26 @@ import ( "fmt" "os" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" - "github.com/slack-go/slack" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) type PluginConfig struct { - Name string `yaml:"name"` - Webhook string `yaml:"webhook"` - LogLevel *string `yaml:"log_level"` + Name string `yaml:"name"` + Webhook string `yaml:"webhook"` + Channel string `yaml:"channel"` + Username string `yaml:"username"` + IconEmoji string `yaml:"icon_emoji"` + IconURL string `yaml:"icon_url"` + LogLevel *string `yaml:"log_level"` } type Notify struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -33,15 +39,22 @@ func (n *Notify) Notify(ctx context.Context, notification *protobufs.Notificatio if _, ok := n.ConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) } + cfg := n.ConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { logger.SetLevel(hclog.LevelFromString(*cfg.LogLevel)) } + logger.Info(fmt.Sprintf("found notify signal for %s config", notification.Name)) logger.Debug(fmt.Sprintf("posting to %s webhook, message %s", cfg.Webhook, notification.Text)) - err := slack.PostWebhookContext(ctx, n.ConfigByName[notification.Name].Webhook, &slack.WebhookMessage{ - Text: notification.Text, + + err := slack.PostWebhookContext(ctx, cfg.Webhook, &slack.WebhookMessage{ + Text: notification.Text, + Channel: cfg.Channel, + Username: cfg.Username, + IconEmoji: cfg.IconEmoji, + IconURL: cfg.IconURL, }) if err != nil { logger.Error(err.Error()) @@ -52,16 +65,19 @@ func (n *Notify) Notify(ctx context.Context, notification *protobufs.Notificatio func (n *Notify) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { d := PluginConfig{} + if err := yaml.Unmarshal(config.Config, &d); err != nil { return nil, err } + n.ConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("Slack plugin '%s' use URL '%s'", d.Name, d.Webhook)) + return &protobufs.Empty{}, nil } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -70,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "slack": &protobufs.NotifierPlugin{ + "slack": &csplugin.NotifierPlugin{ Impl: &Notify{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/cmd/notification-slack/slack.yaml b/cmd/notification-slack/slack.yaml index 4768e869780..677d4b757c1 100644 --- a/cmd/notification-slack/slack.yaml +++ b/cmd/notification-slack/slack.yaml @@ -28,6 +28,12 @@ format: | webhook: +# API request data as defined by the Slack webhook API. +#channel: +#username: +#icon_emoji: +#icon_url: + --- # type: slack diff --git a/cmd/notification-splunk/main.go b/cmd/notification-splunk/main.go index b24aa538f9a..e18f416c14a 100644 --- a/cmd/notification-splunk/main.go +++ b/cmd/notification-splunk/main.go @@ -10,11 +10,12 @@ import ( "os" "strings" - "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" + "gopkg.in/yaml.v3" - "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ @@ -32,6 +33,7 @@ type PluginConfig struct { } type Splunk struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig Client http.Client } @@ -44,6 +46,7 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio if _, ok := s.PluginConfigByName[notification.Name]; !ok { return &protobufs.Empty{}, fmt.Errorf("splunk invalid config name %s", notification.Name) } + cfg := s.PluginConfigByName[notification.Name] if cfg.LogLevel != nil && *cfg.LogLevel != "" { @@ -53,6 +56,7 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio logger.Info(fmt.Sprintf("received notify signal for %s config", notification.Name)) p := Payload{Event: notification.Text} + data, err := json.Marshal(p) if err != nil { return &protobufs.Empty{}, err @@ -65,6 +69,7 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio req.Header.Add("Authorization", fmt.Sprintf("Splunk %s", cfg.Token)) logger.Debug(fmt.Sprintf("posting event %s to %s", string(data), req.URL)) + resp, err := s.Client.Do(req.WithContext(ctx)) if err != nil { return &protobufs.Empty{}, err @@ -73,15 +78,19 @@ func (s *Splunk) Notify(ctx context.Context, notification *protobufs.Notificatio if resp.StatusCode != http.StatusOK { content, err := io.ReadAll(resp.Body) if err != nil { - return &protobufs.Empty{}, fmt.Errorf("got non 200 response and failed to read error %s", err) + return &protobufs.Empty{}, fmt.Errorf("got non 200 response and failed to read error %w", err) } + return &protobufs.Empty{}, fmt.Errorf("got non 200 response %s", string(content)) } + respData, err := io.ReadAll(resp.Body) if err != nil { - return &protobufs.Empty{}, fmt.Errorf("failed to read response body got error %s", err) + return &protobufs.Empty{}, fmt.Errorf("failed to read response body got error %w", err) } + logger.Debug(fmt.Sprintf("got response %s", string(respData))) + return &protobufs.Empty{}, nil } @@ -90,11 +99,12 @@ func (s *Splunk) Configure(ctx context.Context, config *protobufs.Config) (*prot err := yaml.Unmarshal(config.Config, &d) s.PluginConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("Splunk plugin '%s' use URL '%s'", d.Name, d.URL)) + return &protobufs.Empty{}, err } func main() { - var handshake = plugin.HandshakeConfig{ + handshake := plugin.HandshakeConfig{ ProtocolVersion: 1, MagicCookieKey: "CROWDSEC_PLUGIN_KEY", MagicCookieValue: os.Getenv("CROWDSEC_PLUGIN_KEY"), @@ -109,7 +119,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "splunk": &protobufs.NotifierPlugin{ + "splunk": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/config/crowdsec.cron.daily b/config/crowdsec.cron.daily index 1c110df38fc..9c488d29884 100644 --- a/config/crowdsec.cron.daily +++ b/config/crowdsec.cron.daily @@ -2,12 +2,13 @@ test -x /usr/bin/cscli || exit 0 +# splay hub upgrade and crowdsec reload +sleep "$(seq 1 300 | shuf -n 1)" + /usr/bin/cscli --error hub update upgraded=$(/usr/bin/cscli --error hub upgrade) if [ -n "$upgraded" ]; then - # splay initial metrics push - sleep $(seq 1 90 | shuf -n 1) systemctl reload crowdsec fi diff --git a/config/crowdsec.service b/config/crowdsec.service index 147cae4946e..65a8d30bc5f 100644 --- a/config/crowdsec.service +++ b/config/crowdsec.service @@ -8,6 +8,7 @@ Environment=LC_ALL=C LANG=C ExecStartPre=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecStart=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml #ExecStartPost=/bin/sleep 0.1 +ExecReload=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=60 diff --git a/config/profiles.yaml b/config/profiles.yaml index 9d81c9298a2..c4982acd978 100644 --- a/config/profiles.yaml +++ b/config/profiles.yaml @@ -12,3 +12,18 @@ decisions: # - http_default # Set the required http parameters in /etc/crowdsec/notifications/http.yaml before enabling this. # - email_default # Set the required email parameters in /etc/crowdsec/notifications/email.yaml before enabling this. on_success: break +--- +name: default_range_remediation +#debug: true +filters: + - Alert.Remediation == true && Alert.GetScope() == "Range" +decisions: + - type: ban + duration: 4h +#duration_expr: Sprintf('%dh', (GetDecisionsCount(Alert.GetValue()) + 1) * 4) +# notifications: +# - slack_default # Set the webhook in /etc/crowdsec/notifications/slack.yaml before enabling this. +# - splunk_default # Set the splunk url and token in /etc/crowdsec/notifications/splunk.yaml before enabling this. +# - http_default # Set the required http parameters in /etc/crowdsec/notifications/http.yaml before enabling this. +# - email_default # Set the required email parameters in /etc/crowdsec/notifications/email.yaml before enabling this. +on_success: break diff --git a/debian/control b/debian/control index 4673284e7b4..0ee08b71f85 100644 --- a/debian/control +++ b/debian/control @@ -8,3 +8,4 @@ Package: crowdsec Architecture: any Description: Crowdsec - An open-source, lightweight agent to detect and respond to bad behaviors. It also automatically benefits from our global community-wide IP reputation database Depends: coreutils +Suggests: cron diff --git a/debian/crowdsec.service b/debian/crowdsec.service index b65558f70d3..c1a5e403745 100644 --- a/debian/crowdsec.service +++ b/debian/crowdsec.service @@ -8,6 +8,7 @@ Environment=LC_ALL=C LANG=C ExecStartPre=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecStart=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml #ExecStartPost=/bin/sleep 0.1 +ExecReload=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=60 diff --git a/debian/install b/debian/install index 3153244b8e9..fa422cac8d9 100644 --- a/debian/install +++ b/debian/install @@ -11,3 +11,4 @@ cmd/notification-http/http.yaml etc/crowdsec/notifications/ cmd/notification-splunk/splunk.yaml etc/crowdsec/notifications/ cmd/notification-email/email.yaml etc/crowdsec/notifications/ cmd/notification-sentinel/sentinel.yaml etc/crowdsec/notifications/ +cmd/notification-file/file.yaml etc/crowdsec/notifications/ diff --git a/debian/rules b/debian/rules index 655af3dfeea..c11771282ea 100755 --- a/debian/rules +++ b/debian/rules @@ -17,6 +17,7 @@ override_dh_auto_install: mkdir -p debian/crowdsec/usr/bin mkdir -p debian/crowdsec/etc/crowdsec + mkdir -p debian/crowdsec/etc/crowdsec/acquis.d mkdir -p debian/crowdsec/usr/share/crowdsec mkdir -p debian/crowdsec/etc/crowdsec/hub/ mkdir -p debian/crowdsec/usr/share/crowdsec/config @@ -30,6 +31,7 @@ override_dh_auto_install: install -m 551 cmd/notification-splunk/notification-splunk debian/crowdsec/usr/lib/crowdsec/plugins/ install -m 551 cmd/notification-email/notification-email debian/crowdsec/usr/lib/crowdsec/plugins/ install -m 551 cmd/notification-sentinel/notification-sentinel debian/crowdsec/usr/lib/crowdsec/plugins/ + install -m 551 cmd/notification-file/notification-file debian/crowdsec/usr/lib/crowdsec/plugins/ cp cmd/crowdsec/crowdsec debian/crowdsec/usr/bin cp cmd/crowdsec-cli/cscli debian/crowdsec/usr/bin diff --git a/docker/README.md b/docker/README.md index 5e39838a175..ad31d10aed6 100644 --- a/docker/README.md +++ b/docker/README.md @@ -333,6 +333,9 @@ config.yaml) each time the container is run. | `DISABLE_APPSEC_RULES` | | Appsec rules files to remove, separated by space | | | | | | __Log verbosity__ | | | +| `LEVEL_FATAL` | false | Force FATAL level for the container log | +| `LEVEL_ERROR` | false | Force ERROR level for the container log | +| `LEVEL_WARN` | false | Force WARN level for the container log | | `LEVEL_INFO` | false | Force INFO level for the container log | | `LEVEL_DEBUG` | false | Force DEBUG level for the container log | | `LEVEL_TRACE` | false | Force TRACE level (VERY verbose) for the container log | diff --git a/docker/docker_start.sh b/docker/docker_start.sh index dd96184ccbc..fb87c1eff9b 100755 --- a/docker/docker_start.sh +++ b/docker/docker_start.sh @@ -6,6 +6,9 @@ set -e shopt -s inherit_errexit +# Note that "if function_name" in bash matches when the function returns 0, +# meaning successful execution. + # match true, TRUE, True, tRuE, etc. istrue() { case "$(echo "$1" | tr '[:upper:]' '[:lower:]')" in @@ -50,6 +53,52 @@ cscli() { command cscli -c "$CONFIG_FILE" "$@" } +run_hub_update() { + index_modification_time=$(stat -c %Y /etc/crowdsec/hub/.index.json 2>/dev/null) + # Run cscli hub update if no date or if the index file is older than 24h + if [ -z "$index_modification_time" ] || [ $(( $(date +%s) - index_modification_time )) -gt 86400 ]; then + cscli hub update --with-content + else + echo "Skipping hub update, index file is recent" + fi +} + +is_mounted() { + path=$(readlink -f "$1") + mounts=$(awk '{print $2}' /proc/mounts) + while true; do + if grep -qE ^"$path"$ <<< "$mounts"; then + echo "$path was found in a volume" + return 0 + fi + path=$(dirname "$path") + if [ "$path" = "/" ]; then + return 1 + fi + done + return 1 #unreachable +} + +run_hub_update_if_from_volume() { + if is_mounted "/etc/crowdsec/hub/.index.json"; then + echo "Running hub update" + run_hub_update + else + echo "Skipping hub update, index file is not in a volume" + fi +} + +run_hub_upgrade_if_from_volume() { + isfalse "$NO_HUB_UPGRADE" || return 0 + if is_mounted "/var/lib/crowdsec/data"; then + echo "Running hub upgrade" + cscli hub upgrade + else + echo "Skipping hub upgrade, data directory is not in a volume" + fi + +} + # conf_get [file_path] # retrieve a value from a file (by default $CONFIG_FILE) conf_get() { @@ -119,7 +168,12 @@ cscli_if_clean() { error_only="" echo "Running: cscli $error_only $itemtype $action \"$obj\" $*" # shellcheck disable=SC2086 - cscli $error_only "$itemtype" "$action" "$obj" "$@" + if ! cscli $error_only "$itemtype" "$action" "$obj" "$@"; then + echo "Failed to $action $itemtype/$obj, running hub update before retrying" + run_hub_update + # shellcheck disable=SC2086 + cscli $error_only "$itemtype" "$action" "$obj" "$@" + fi fi done } @@ -159,15 +213,16 @@ if [ -n "$CERT_FILE" ] || [ -n "$KEY_FILE" ] ; then export LAPI_KEY_FILE=${LAPI_KEY_FILE:-$KEY_FILE} fi -# Check and prestage databases -for geodb in GeoLite2-ASN.mmdb GeoLite2-City.mmdb; do - # We keep the pre-populated geoip databases in /staging instead of /var, - # because if the data directory is bind-mounted from the host, it will be - # empty and the files will be out of reach, requiring a runtime download. - # We link to them to save about 80Mb compared to cp/mv. - if [ ! -e "/var/lib/crowdsec/data/$geodb" ] && [ -e "/staging/var/lib/crowdsec/data/$geodb" ]; then - mkdir -p /var/lib/crowdsec/data - ln -s "/staging/var/lib/crowdsec/data/$geodb" /var/lib/crowdsec/data/ +# Link the preloaded data files when the data dir is mounted (common case) +# The symlinks can be overridden by hub upgrade +for target in "/staging/var/lib/crowdsec/data"/*; do + fname="$(basename "$target")" + # skip the db and wal files + if [[ $fname == crowdsec.db* ]]; then + continue + fi + if [ ! -e "/var/lib/crowdsec/data/$fname" ]; then + ln -s "$target" "/var/lib/crowdsec/data/$fname" fi done @@ -279,10 +334,12 @@ fi # crowdsec sqlite database permissions if [ "$GID" != "" ]; then if istrue "$(conf_get '.db_config.type == "sqlite"')"; then + # force the creation of the db file(s) + cscli machines inspect create-db --error >/dev/null 2>&1 || : # don't fail if the db is not there yet - chown -f ":$GID" "$(conf_get '.db_config.db_path')" 2>/dev/null \ - && echo "sqlite database permissions updated" \ - || true + if chown -f ":$GID" "$(conf_get '.db_config.db_path')" 2>/dev/null; then + echo "sqlite database permissions updated" + fi fi fi @@ -304,11 +361,8 @@ conf_set_if "$PLUGIN_DIR" '.config_paths.plugin_dir = strenv(PLUGIN_DIR)' ## Install hub items -cscli hub update || true - -if isfalse "$NO_HUB_UPGRADE"; then - cscli hub upgrade || true -fi +run_hub_update_if_from_volume || true +run_hub_upgrade_if_from_volume || true cscli_if_clean parsers install crowdsecurity/docker-logs cscli_if_clean parsers install crowdsecurity/cri-logs @@ -453,5 +507,17 @@ if istrue "$LEVEL_INFO"; then ARGS="$ARGS -info" fi +if istrue "$LEVEL_WARN"; then + ARGS="$ARGS -warning" +fi + +if istrue "$LEVEL_ERROR"; then + ARGS="$ARGS -error" +fi + +if istrue "$LEVEL_FATAL"; then + ARGS="$ARGS -fatal" +fi + # shellcheck disable=SC2086 exec crowdsec $ARGS diff --git a/docker/preload-hub-items b/docker/preload-hub-items new file mode 100755 index 00000000000..45155d17af9 --- /dev/null +++ b/docker/preload-hub-items @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -eu + +# pre-download everything but don't install anything + +echo "Pre-downloading Hub content..." + +types=$(cscli hub types -o raw) + +for itemtype in $types; do + ALL_ITEMS=$(cscli "$itemtype" list -a -o json | itemtype="$itemtype" yq '.[env(itemtype)][] | .name') + if [[ -n "${ALL_ITEMS}" ]]; then + #shellcheck disable=SC2086 + cscli "$itemtype" install \ + $ALL_ITEMS \ + --download-only \ + --error + fi +done + +echo " done." \ No newline at end of file diff --git a/docker/test/Pipfile.lock b/docker/test/Pipfile.lock index 75437876b72..99184d9f2a2 100644 --- a/docker/test/Pipfile.lock +++ b/docker/test/Pipfile.lock @@ -18,69 +18,84 @@ "default": { "certifi": { "hashes": [ - "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1", - "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474" + "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", + "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9" ], "markers": "python_version >= '3.6'", - "version": "==2023.11.17" + "version": "==2024.8.30" }, "cffi": { "hashes": [ - "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc", - "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a", - "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417", - "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab", - "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520", - "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36", - "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743", - "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8", - "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed", - "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684", - "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56", - "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324", - "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d", - "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235", - "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e", - "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088", - "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000", - "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7", - "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e", - "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673", - "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c", - "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe", - "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2", - "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098", - "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8", - "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a", - "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0", - "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b", - "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896", - "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e", - "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9", - "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2", - "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b", - "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6", - "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404", - "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f", - "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0", - "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4", - "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc", - "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936", - "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba", - "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872", - "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb", - "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614", - "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1", - "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d", - "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969", - "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b", - "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4", - "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627", - "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956", - "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357" + "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", + "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", + "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1", + "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", + "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", + "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", + "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8", + "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36", + "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", + "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", + "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc", + "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", + "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", + "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", + "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", + "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", + "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", + "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", + "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", + "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b", + "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", + "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", + "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c", + "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", + "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", + "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", + "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8", + "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1", + "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", + "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", + "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", + "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595", + "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0", + "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", + "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", + "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", + "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", + "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", + "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", + "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16", + "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", + "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e", + "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", + "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964", + "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", + "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576", + "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", + "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3", + "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662", + "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", + "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", + "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", + "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", + "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", + "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", + "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", + "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", + "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9", + "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7", + "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", + "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a", + "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", + "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", + "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", + "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", + "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87", + "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b" ], "markers": "platform_python_implementation != 'PyPy'", - "version": "==1.16.0" + "version": "==1.17.1" }, "charset-normalizer": { "hashes": [ @@ -180,65 +195,60 @@ }, "cryptography": { "hashes": [ - "sha256:087887e55e0b9c8724cf05361357875adb5c20dec27e5816b653492980d20380", - "sha256:09a77e5b2e8ca732a19a90c5bca2d124621a1edb5438c5daa2d2738bfeb02589", - "sha256:130c0f77022b2b9c99d8cebcdd834d81705f61c68e91ddd614ce74c657f8b3ea", - "sha256:141e2aa5ba100d3788c0ad7919b288f89d1fe015878b9659b307c9ef867d3a65", - "sha256:28cb2c41f131a5758d6ba6a0504150d644054fd9f3203a1e8e8d7ac3aea7f73a", - "sha256:2f9f14185962e6a04ab32d1abe34eae8a9001569ee4edb64d2304bf0d65c53f3", - "sha256:320948ab49883557a256eab46149df79435a22d2fefd6a66fe6946f1b9d9d008", - "sha256:36d4b7c4be6411f58f60d9ce555a73df8406d484ba12a63549c88bd64f7967f1", - "sha256:3b15c678f27d66d247132cbf13df2f75255627bcc9b6a570f7d2fd08e8c081d2", - "sha256:3dbd37e14ce795b4af61b89b037d4bc157f2cb23e676fa16932185a04dfbf635", - "sha256:4383b47f45b14459cab66048d384614019965ba6c1a1a141f11b5a551cace1b2", - "sha256:44c95c0e96b3cb628e8452ec060413a49002a247b2b9938989e23a2c8291fc90", - "sha256:4b063d3413f853e056161eb0c7724822a9740ad3caa24b8424d776cebf98e7ee", - "sha256:52ed9ebf8ac602385126c9a2fe951db36f2cb0c2538d22971487f89d0de4065a", - "sha256:55d1580e2d7e17f45d19d3b12098e352f3a37fe86d380bf45846ef257054b242", - "sha256:5ef9bc3d046ce83c4bbf4c25e1e0547b9c441c01d30922d812e887dc5f125c12", - "sha256:5fa82a26f92871eca593b53359c12ad7949772462f887c35edaf36f87953c0e2", - "sha256:61321672b3ac7aade25c40449ccedbc6db72c7f5f0fdf34def5e2f8b51ca530d", - "sha256:701171f825dcab90969596ce2af253143b93b08f1a716d4b2a9d2db5084ef7be", - "sha256:841ec8af7a8491ac76ec5a9522226e287187a3107e12b7d686ad354bb78facee", - "sha256:8a06641fb07d4e8f6c7dda4fc3f8871d327803ab6542e33831c7ccfdcb4d0ad6", - "sha256:8e88bb9eafbf6a4014d55fb222e7360eef53e613215085e65a13290577394529", - "sha256:a00aee5d1b6c20620161984f8ab2ab69134466c51f58c052c11b076715e72929", - "sha256:a047682d324ba56e61b7ea7c7299d51e61fd3bca7dad2ccc39b72bd0118d60a1", - "sha256:a7ef8dd0bf2e1d0a27042b231a3baac6883cdd5557036f5e8df7139255feaac6", - "sha256:ad28cff53f60d99a928dfcf1e861e0b2ceb2bc1f08a074fdd601b314e1cc9e0a", - "sha256:b9097a208875fc7bbeb1286d0125d90bdfed961f61f214d3f5be62cd4ed8a446", - "sha256:b97fe7d7991c25e6a31e5d5e795986b18fbbb3107b873d5f3ae6dc9a103278e9", - "sha256:e0ec52ba3c7f1b7d813cd52649a5b3ef1fc0d433219dc8c93827c57eab6cf888", - "sha256:ea2c3ffb662fec8bbbfce5602e2c159ff097a4631d96235fcf0fb00e59e3ece4", - "sha256:fa3dec4ba8fb6e662770b74f62f1a0c7d4e37e25b58b2bf2c1be4c95372b4a33", - "sha256:fbeb725c9dc799a574518109336acccaf1303c30d45c075c665c0793c2f79a7f" + "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494", + "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806", + "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d", + "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062", + "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2", + "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4", + "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1", + "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85", + "sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84", + "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042", + "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d", + "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962", + "sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2", + "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa", + "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d", + "sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365", + "sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96", + "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47", + "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d", + "sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d", + "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c", + "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb", + "sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277", + "sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172", + "sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034", + "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a", + "sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289" ], "markers": "python_version >= '3.7'", - "version": "==42.0.2" + "version": "==43.0.1" }, "docker": { "hashes": [ - "sha256:12ba681f2777a0ad28ffbcc846a69c31b4dfd9752b47eb425a274ee269c5e14b", - "sha256:323736fb92cd9418fc5e7133bc953e11a9da04f4483f828b527db553f1e7e5a3" + "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", + "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0" ], "markers": "python_version >= '3.8'", - "version": "==7.0.0" + "version": "==7.1.0" }, "execnet": { "hashes": [ - "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41", - "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af" + "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", + "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3" ], - "markers": "python_version >= '3.7'", - "version": "==2.0.2" + "markers": "python_version >= '3.8'", + "version": "==2.1.1" }, "idna": { "hashes": [ - "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca", - "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f" + "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", + "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3" ], - "markers": "python_version >= '3.5'", - "version": "==3.6" + "markers": "python_version >= '3.6'", + "version": "==3.10" }, "iniconfig": { "hashes": [ @@ -250,56 +260,58 @@ }, "packaging": { "hashes": [ - "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", - "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7" + "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002", + "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124" ], - "markers": "python_version >= '3.7'", - "version": "==23.2" + "markers": "python_version >= '3.8'", + "version": "==24.1" }, "pluggy": { "hashes": [ - "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981", - "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be" + "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", + "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669" ], "markers": "python_version >= '3.8'", - "version": "==1.4.0" + "version": "==1.5.0" }, "psutil": { "hashes": [ - "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d", - "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73", - "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8", - "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2", - "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e", - "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36", - "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7", - "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c", - "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee", - "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421", - "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf", - "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81", - "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0", - "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631", - "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4", - "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8" + "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", + "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0", + "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c", + "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", + "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3", + "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c", + "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", + "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3", + "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", + "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", + "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6", + "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d", + "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c", + "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", + "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132", + "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14", + "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", - "version": "==5.9.8" + "version": "==6.0.0" }, "pycparser": { "hashes": [ - "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9", - "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206" + "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", + "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc" ], - "version": "==2.21" + "markers": "python_version >= '3.8'", + "version": "==2.22" }, "pytest": { "hashes": [ - "sha256:249b1b0864530ba251b7438274c4d251c58d868edaaec8762893ad4a0d71c36c", - "sha256:50fb9cbe836c3f20f0dfa99c565201fb75dc54c8d76373cd1bde06b06657bdb6" + "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", + "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2" ], "markers": "python_version >= '3.8'", - "version": "==8.0.0" + "version": "==8.3.3" }, "pytest-cs": { "git": "https://github.com/crowdsecurity/pytest-cs.git", @@ -327,6 +339,7 @@ "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24" ], "index": "pypi", + "markers": "python_version >= '3.7'", "version": "==3.5.0" }, "python-dotenv": { @@ -339,68 +352,70 @@ }, "pyyaml": { "hashes": [ - "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5", - "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc", - "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df", - "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741", - "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206", - "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27", - "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595", - "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62", - "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98", - "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696", - "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290", - "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9", - "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d", - "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6", - "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867", - "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47", - "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486", - "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6", - "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3", - "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007", - "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938", - "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0", - "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c", - "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735", - "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d", - "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28", - "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4", - "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba", - "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8", - "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef", - "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5", - "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd", - "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3", - "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0", - "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515", - "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c", - "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c", - "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924", - "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34", - "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43", - "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859", - "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673", - "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54", - "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a", - "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b", - "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab", - "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa", - "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c", - "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585", - "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d", - "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f" + "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff", + "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", + "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", + "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", + "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", + "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", + "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", + "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", + "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", + "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", + "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a", + "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", + "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", + "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", + "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", + "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", + "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", + "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a", + "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", + "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", + "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", + "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631", + "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d", + "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", + "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", + "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", + "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", + "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", + "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", + "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706", + "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", + "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", + "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", + "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083", + "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", + "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", + "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", + "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f", + "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", + "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", + "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", + "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", + "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", + "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", + "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5", + "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d", + "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", + "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", + "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", + "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", + "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", + "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", + "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4" ], - "markers": "python_version >= '3.6'", - "version": "==6.0.1" + "markers": "python_version >= '3.8'", + "version": "==6.0.2" }, "requests": { "hashes": [ - "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f", - "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1" + "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", + "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6" ], - "markers": "python_version >= '3.7'", - "version": "==2.31.0" + "markers": "python_version >= '3.8'", + "version": "==2.32.3" }, "trustme": { "hashes": [ @@ -412,11 +427,11 @@ }, "urllib3": { "hashes": [ - "sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20", - "sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224" + "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", + "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9" ], "markers": "python_version >= '3.8'", - "version": "==2.2.0" + "version": "==2.2.3" } }, "develop": { @@ -437,11 +452,11 @@ }, "executing": { "hashes": [ - "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147", - "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc" + "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf", + "sha256:8ea27ddd260da8150fa5a708269c4a10e76161e2496ec3e587da9e3c0fe4b9ab" ], - "markers": "python_version >= '3.5'", - "version": "==2.0.1" + "markers": "python_version >= '3.8'", + "version": "==2.1.0" }, "gnureadline": { "hashes": [ @@ -482,15 +497,16 @@ "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726" ], "index": "pypi", + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", "version": "==0.13.13" }, "ipython": { "hashes": [ - "sha256:1050a3ab8473488d7eee163796b02e511d0735cf43a04ba2a8348bd0f2eaf8a5", - "sha256:48fbc236fbe0e138b88773fa0437751f14c3645fb483f1d4c5dee58b37e5ce73" + "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a", + "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35" ], "markers": "python_version >= '3.11'", - "version": "==8.21.0" + "version": "==8.28.0" }, "jedi": { "hashes": [ @@ -502,35 +518,35 @@ }, "matplotlib-inline": { "hashes": [ - "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311", - "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304" + "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", + "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca" ], - "markers": "python_version >= '3.5'", - "version": "==0.1.6" + "markers": "python_version >= '3.8'", + "version": "==0.1.7" }, "parso": { "hashes": [ - "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0", - "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75" + "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", + "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d" ], "markers": "python_version >= '3.6'", - "version": "==0.8.3" + "version": "==0.8.4" }, "pexpect": { "hashes": [ "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f" ], - "markers": "sys_platform != 'win32'", + "markers": "sys_platform != 'win32' and sys_platform != 'emscripten'", "version": "==4.9.0" }, "prompt-toolkit": { "hashes": [ - "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d", - "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6" + "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90", + "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e" ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.0.43" + "version": "==3.0.48" }, "ptyprocess": { "hashes": [ @@ -541,18 +557,18 @@ }, "pure-eval": { "hashes": [ - "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350", - "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3" + "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", + "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42" ], - "version": "==0.2.2" + "version": "==0.2.3" }, "pygments": { "hashes": [ - "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c", - "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367" + "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", + "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a" ], - "markers": "python_version >= '3.7'", - "version": "==2.17.2" + "markers": "python_version >= '3.8'", + "version": "==2.18.0" }, "six": { "hashes": [ @@ -571,11 +587,11 @@ }, "traitlets": { "hashes": [ - "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74", - "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e" + "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", + "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f" ], "markers": "python_version >= '3.8'", - "version": "==5.14.1" + "version": "==5.14.3" }, "wcwidth": { "hashes": [ diff --git a/docker/test/default.env b/docker/test/default.env index c46fdab7f1d..9607c8aaa5b 100644 --- a/docker/test/default.env +++ b/docker/test/default.env @@ -6,7 +6,7 @@ CROWDSEC_TEST_VERSION="dev" # All of the following flavors will be tested when using the "flavor" fixture CROWDSEC_TEST_FLAVORS="full" # CROWDSEC_TEST_FLAVORS="full,slim,debian" -# CROWDSEC_TEST_FLAVORS="full,slim,debian,geoip,plugins-debian-slim,debian-geoip,debian-plugins" +# CROWDSEC_TEST_FLAVORS="full,slim,debian,debian-slim" # network to use CROWDSEC_TEST_NETWORK="net-test" diff --git a/docker/test/tests/test_bouncer.py b/docker/test/tests/test_bouncer.py index 1324c3bd38c..98b86de858c 100644 --- a/docker/test/tests/test_bouncer.py +++ b/docker/test/tests/test_bouncer.py @@ -36,8 +36,6 @@ def test_register_bouncer_env(crowdsec, flavor): bouncer1, bouncer2 = j assert bouncer1['name'] == 'bouncer1name' assert bouncer2['name'] == 'bouncer2name' - assert bouncer1['api_key'] == hex512('bouncer1key') - assert bouncer2['api_key'] == hex512('bouncer2key') # add a second bouncer at runtime res = cs.cont.exec_run('cscli bouncers add bouncer3name -k bouncer3key') @@ -48,7 +46,6 @@ def test_register_bouncer_env(crowdsec, flavor): assert len(j) == 3 bouncer3 = j[2] assert bouncer3['name'] == 'bouncer3name' - assert bouncer3['api_key'] == hex512('bouncer3key') # remove all bouncers res = cs.cont.exec_run('cscli bouncers delete bouncer1name bouncer2name bouncer3name') diff --git a/docker/test/tests/test_flavors.py b/docker/test/tests/test_flavors.py index 223cf995cba..7e78b8d681b 100644 --- a/docker/test/tests/test_flavors.py +++ b/docker/test/tests/test_flavors.py @@ -42,7 +42,7 @@ def test_flavor_content(crowdsec, flavor): x = cs.cont.exec_run( 'ls -1 /usr/local/lib/crowdsec/plugins/') stdout = x.output.decode() - if 'slim' in flavor or 'geoip' in flavor: + if 'slim' in flavor: # the exact return code and full message depend # on the 'ls' implementation (busybox vs coreutils) assert x.exit_code != 0 diff --git a/docker/test/tests/test_tls.py b/docker/test/tests/test_tls.py index 591afe0d303..d2f512fcbc1 100644 --- a/docker/test/tests/test_tls.py +++ b/docker/test/tests/test_tls.py @@ -22,8 +22,7 @@ def test_missing_key_file(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: - # XXX: this message appears twice, is that normal? - cs.wait_for_log("*while starting API server: missing TLS key file*") + cs.wait_for_log("*local API server stopped with error: missing TLS key file*") def test_missing_cert_file(crowdsec, flavor): @@ -35,7 +34,7 @@ def test_missing_cert_file(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: - cs.wait_for_log("*while starting API server: missing TLS cert file*") + cs.wait_for_log("*local API server stopped with error: missing TLS cert file*") def test_tls_missing_ca(crowdsec, flavor, certs_dir): @@ -282,7 +281,7 @@ def test_tls_client_ou(crowdsec, flavor, certs_dir): lapi.wait_for_http(8080, '/health', want_status=None) with cs_agent as agent: lapi.wait_for_log([ - "*client certificate OU (?custom-client-ou?) doesn't match expected OU (?agent-ou?)*", + "*client certificate OU ?custom-client-ou? doesn't match expected OU ?agent-ou?*", ]) lapi_env['AGENTS_ALLOWED_OU'] = 'custom-client-ou' diff --git a/go.mod b/go.mod index d61c191c14f..f28f21c6eb4 100644 --- a/go.mod +++ b/go.mod @@ -1,43 +1,42 @@ module github.com/crowdsecurity/crowdsec -go 1.21 +go 1.22 // Don't use the toolchain directive to avoid uncontrolled downloads during // a build, especially in sandboxed environments (freebsd, gentoo...). // toolchain go1.21.3 require ( - entgo.io/ent v0.12.4 + entgo.io/ent v0.13.1 github.com/AlecAivazis/survey/v2 v2.3.7 github.com/Masterminds/semver/v3 v3.2.1 github.com/Masterminds/sprig/v3 v3.2.3 - github.com/agext/levenshtein v1.2.1 + github.com/agext/levenshtein v1.2.3 github.com/alexliesenfeld/health v0.8.0 - github.com/antonmedv/expr v1.15.3 - github.com/appleboy/gin-jwt/v2 v2.8.0 - github.com/aquasecurity/table v1.8.0 - github.com/aws/aws-lambda-go v1.41.0 - github.com/aws/aws-sdk-go v1.48.15 - github.com/beevik/etree v1.1.0 - github.com/blackfireio/osinfo v1.0.3 + github.com/appleboy/gin-jwt/v2 v2.9.2 + github.com/aws/aws-lambda-go v1.47.0 + github.com/aws/aws-sdk-go v1.52.0 + github.com/beevik/etree v1.4.1 + github.com/blackfireio/osinfo v1.0.5 github.com/bluele/gcache v0.0.2 github.com/buger/jsonparser v1.1.1 - github.com/c-robinson/iplib v1.0.3 - github.com/cespare/xxhash/v2 v2.2.0 + github.com/c-robinson/iplib v1.0.8 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/corazawaf/libinjection-go v0.1.2 github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 - github.com/crowdsecurity/go-cs-lib v0.0.6 - github.com/crowdsecurity/grokky v0.2.1 + github.com/crowdsecurity/go-cs-lib v0.0.15 + github.com/crowdsecurity/grokky v0.2.2 github.com/crowdsecurity/machineid v1.0.2 github.com/davecgh/go-spew v1.1.1 - github.com/dghubble/sling v1.3.0 - github.com/docker/docker v24.0.7+incompatible + github.com/dghubble/sling v1.4.2 + github.com/docker/docker v24.0.9+incompatible github.com/docker/go-connections v0.4.0 - github.com/enescakir/emoji v1.0.0 - github.com/fatih/color v1.15.0 - github.com/fsnotify/fsnotify v1.6.0 + github.com/expr-lang/expr v1.16.9 + github.com/fatih/color v1.16.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/gin-gonic/gin v1.9.1 - github.com/go-co-op/gocron v1.17.0 + github.com/go-co-op/gocron v1.37.0 github.com/go-openapi/errors v0.20.1 github.com/go-openapi/strfmt v0.19.11 github.com/go-openapi/swag v0.22.3 @@ -46,8 +45,8 @@ require ( github.com/goccy/go-yaml v1.11.0 github.com/gofrs/uuid v4.0.0+incompatible github.com/golang-jwt/jwt/v4 v4.5.0 - github.com/google/go-querystring v1.0.0 - github.com/google/uuid v1.3.0 + github.com/google/go-querystring v1.1.0 + github.com/google/uuid v1.6.0 github.com/google/winops v0.0.0-20230712152054-af9b550d0601 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e github.com/gorilla/websocket v1.5.0 @@ -56,16 +55,17 @@ require ( github.com/hashicorp/go-version v1.2.1 github.com/hexops/gotextdiff v1.0.3 github.com/ivanpirog/coloredcobra v1.0.1 - github.com/jackc/pgx/v4 v4.14.1 + github.com/jackc/pgx/v4 v4.18.2 github.com/jarcoal/httpmock v1.1.0 + github.com/jedib0t/go-pretty/v6 v6.5.9 github.com/jszwec/csvutil v1.5.1 github.com/lithammer/dedent v1.1.0 - github.com/mattn/go-isatty v0.0.19 + github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-sqlite3 v1.14.16 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/nxadm/tail v1.4.8 - github.com/oschwald/geoip2-golang v1.4.0 - github.com/oschwald/maxminddb-golang v1.8.0 + github.com/oschwald/geoip2-golang v1.9.0 + github.com/oschwald/maxminddb-golang v1.12.0 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.16.0 @@ -77,42 +77,41 @@ require ( github.com/shirou/gopsutil/v3 v3.23.5 github.com/sirupsen/logrus v1.9.3 github.com/slack-go/slack v0.12.2 - github.com/spf13/cobra v1.7.0 - github.com/stretchr/testify v1.8.4 + github.com/spf13/cobra v1.8.0 + github.com/stretchr/testify v1.9.0 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 - github.com/wasilibs/go-re2 v1.3.0 + github.com/wasilibs/go-re2 v1.7.0 github.com/xhit/go-simple-mail/v2 v2.16.0 - golang.org/x/crypto v0.17.0 - golang.org/x/mod v0.11.0 - golang.org/x/sys v0.15.0 - golang.org/x/text v0.14.0 - google.golang.org/grpc v1.56.3 - google.golang.org/protobuf v1.31.0 + golang.org/x/crypto v0.26.0 + golang.org/x/mod v0.17.0 + golang.org/x/sys v0.24.0 + golang.org/x/text v0.17.0 + google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 - gotest.tools/v3 v3.5.0 k8s.io/apiserver v0.28.4 ) require ( - ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 // indirect + ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bytedance/sonic v1.9.1 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect - github.com/corazawaf/libinjection-go v0.1.2 // indirect + github.com/bytedance/sonic v1.10.2 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect + github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/creack/pty v1.1.18 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-ole/go-ole v1.2.6 // indirect @@ -125,13 +124,13 @@ require ( github.com/go-openapi/spec v0.20.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-playground/validator/v10 v10.17.0 // indirect github.com/go-stack/stack v1.8.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/glog v1.1.0 // indirect + github.com/golang/glog v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect @@ -139,24 +138,24 @@ require ( github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.10.1 // indirect + github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgproto3/v2 v2.2.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect - github.com/jackc/pgtype v1.9.1 // indirect + github.com/jackc/pgproto3/v2 v2.3.3 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.17.3 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/leodido/go-urn v1.3.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect - github.com/magefile/mage v1.15.0 // indirect + github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-runewidth v0.0.13 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mitchellh/copystructure v1.2.0 // indirect @@ -171,7 +170,7 @@ require ( github.com/oklog/run v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -186,7 +185,7 @@ require ( github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/tetratelabs/wazero v1.2.1 // indirect + github.com/tetratelabs/wazero v1.8.0 // indirect github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -194,22 +193,25 @@ require ( github.com/tklauser/numcpus v0.6.0 // indirect github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect + github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/zclconf/go-cty v1.8.0 // indirect go.mongodb.org/mongo-driver v1.9.4 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/term v0.15.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + golang.org/x/arch v0.7.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/term v0.23.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + gotest.tools/v3 v3.5.0 // indirect k8s.io/api v0.28.4 // indirect k8s.io/apimachinery v0.28.4 // indirect k8s.io/klog/v2 v2.100.1 // indirect diff --git a/go.sum b/go.sum index f5f61594ecd..b2bd77c9915 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,9 @@ -ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 h1:JnYs/y8RJ3+MiIUp+3RgyyeO48VHLAZimqiaZYnMKk8= -ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935/go.mod h1:isZrlzJ5cpoCoKFoY9knZug7Lq4pP1cm8g3XciLZ0Pw= +ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 h1:GwdJbXydHCYPedeeLt4x/lrlIISQ4JTH1mRWuE5ZZ14= +ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43/go.mod h1:uj3pm+hUTVN/X5yfdBexHlZv+1Xu5u5ZbZx7+CDavNU= bitbucket.org/creachadair/stringset v0.0.9 h1:L4vld9nzPt90UZNrXjNelTshD74ps4P5NGs3Iq6yN3o= bitbucket.org/creachadair/stringset v0.0.9/go.mod h1:t+4WcQ4+PXTa8aQdNKe40ZP6iwesoMFWAxPGd3UGjyY= -entgo.io/ent v0.12.4 h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8= -entgo.io/ent v0.12.4/go.mod h1:Y3JVAjtlIk8xVZYSn3t3mf8xlZIn5SAOXZQxD6kKI+Q= +entgo.io/ent v0.13.1 h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE= +entgo.io/ent v0.13.1/go.mod h1:qCEmo+biw3ccBn9OyL4ZK5dfpwg++l1Gxwac5B1206A= github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= @@ -26,8 +26,8 @@ github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDe github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= -github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= -github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= +github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:pzStYMLAXM7CNQjS/Wn+zK9MUxDhSUNfVvnHsyQyjs0= github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ilK+u7u1HoqaDk0mjhh27QJB7PyWMreGffEvOCoEKiY= @@ -39,49 +39,52 @@ github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk5 github.com/alexliesenfeld/health v0.8.0 h1:lCV0i+ZJPTbqP7LfKG7p3qZBl5VhelwUFCIVWl77fgk= github.com/alexliesenfeld/health v0.8.0/go.mod h1:TfNP0f+9WQVWMQRzvMUjlws4ceXKEL3WR+6Hp95HUFc= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= -github.com/antonmedv/expr v1.15.3 h1:q3hOJZNvLvhqE8OHBs1cFRdbXFNKuA+bHmRaI+AmRmI= -github.com/antonmedv/expr v1.15.3/go.mod h1:0E/6TxnOlRNp81GMzX9QfDPAmHo2Phg00y4JUv1ihsE= github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw= github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= -github.com/appleboy/gin-jwt/v2 v2.8.0 h1:Glo7cb9eBR+hj8Y7WzgfkOlqCaNLjP+RV4dNO3fpdps= -github.com/appleboy/gin-jwt/v2 v2.8.0/go.mod h1:KsK7E8HTvRg3vOiumTsr/ntNTHbZ3IbHLe4Eto31p7k= +github.com/appleboy/gin-jwt/v2 v2.9.2 h1:GeS3lm9mb9HMmj7+GNjYUtpp3V1DAQ1TkUFa5poiZ7Y= +github.com/appleboy/gin-jwt/v2 v2.9.2/go.mod h1:mxGjKt9Lrx9Xusy1SrnmsCJMZG6UJwmdHN9bN27/QDw= github.com/appleboy/gofight/v2 v2.1.2 h1:VOy3jow4vIK8BRQJoC/I9muxyYlJ2yb9ht2hZoS3rf4= github.com/appleboy/gofight/v2 v2.1.2/go.mod h1:frW+U1QZEdDgixycTj4CygQ48yLTUhplt43+Wczp3rw= -github.com/aquasecurity/table v1.8.0 h1:9ntpSwrUfjrM6/YviArlx/ZBGd6ix8W+MtojQcM7tv0= -github.com/aquasecurity/table v1.8.0/go.mod h1:eqOmvjjB7AhXFgFqpJUEE/ietg7RrMSJZXyTN8E/wZw= github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/aws/aws-lambda-go v1.41.0 h1:l/5fyVb6Ud9uYd411xdHZzSf2n86TakxzpvIoz7l+3Y= -github.com/aws/aws-lambda-go v1.41.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= +github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI= +github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= -github.com/aws/aws-sdk-go v1.48.15 h1:Gad2C4pLzuZDd5CA0Rvkfko6qUDDTOYru145gkO7w/Y= -github.com/aws/aws-sdk-go v1.48.15/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= -github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= -github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= +github.com/aws/aws-sdk-go v1.52.0 h1:ptgek/4B2v/ljsjYSEvLQ8LTD+SQyrqhOOWvHc/VGPI= +github.com/aws/aws-sdk-go v1.52.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/beevik/etree v1.3.0 h1:hQTc+pylzIKDb23yYprodCWWTt+ojFfUZyzU09a/hmU= +github.com/beevik/etree v1.3.0/go.mod h1:aiPf89g/1k3AShMVAzriilpcE4R/Vuor90y83zVZWFc= +github.com/beevik/etree v1.4.1 h1:PmQJDDYahBGNKDcpdX8uPy1xRCwoCGVUiW669MEirVI= +github.com/beevik/etree v1.4.1/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/blackfireio/osinfo v1.0.3 h1:Yk2t2GTPjBcESv6nDSWZKO87bGMQgO+Hi9OoXPpxX8c= -github.com/blackfireio/osinfo v1.0.3/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA= +github.com/blackfireio/osinfo v1.0.5 h1:6hlaWzfcpb87gRmznVf7wSdhysGqLRz9V/xuSdCEXrA= +github.com/blackfireio/osinfo v1.0.5/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA= github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= -github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU= -github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= +github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE= +github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= +github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4= +github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= +github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA= +github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= +github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0= +github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/corazawaf/libinjection-go v0.1.2 h1:oeiV9pc5rvJ+2oqOqXEAMJousPpGiup6f7Y3nZj5GoM= @@ -91,64 +94,59 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f h1:FkOB9aDw0xzDd14pTarGRLsUNAymONq3dc7zhvsXElg= -github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f/go.mod h1:TrU7Li+z2RHNrPy0TKJ6R65V6Yzpan2sTIRryJJyJso= github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI= github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk= -github.com/crowdsecurity/go-cs-lib v0.0.5 h1:eVLW+BRj3ZYn0xt5/xmgzfbbB8EBo32gM4+WpQQk2e8= -github.com/crowdsecurity/go-cs-lib v0.0.5/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= -github.com/crowdsecurity/go-cs-lib v0.0.6 h1:Ef6MylXe0GaJE9vrfvxEdbHb31+JUP1os+murPz7Pos= -github.com/crowdsecurity/go-cs-lib v0.0.6/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= -github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4= -github.com/crowdsecurity/grokky v0.2.1/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM= +github.com/crowdsecurity/go-cs-lib v0.0.15 h1:zNWqOPVLHgKUstlr6clom9d66S0eIIW66jQG3Y7FEvo= +github.com/crowdsecurity/go-cs-lib v0.0.15/go.mod h1:ePyQyJBxp1W/1bq4YpVAilnLSz7HkzmtI7TRhX187EU= +github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4= +github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM= github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc= github.com/crowdsecurity/machineid v1.0.2/go.mod h1:XWUSlnS0R0+u/JK5ulidwlbceNT3ZOCKteoVQEn6Luo= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dghubble/sling v1.3.0 h1:pZHjCJq4zJvc6qVQ5wN1jo5oNZlNE0+8T/h0XeXBUKU= -github.com/dghubble/sling v1.3.0/go.mod h1:XXShWaBWKzNLhu2OxikSNFrlsvowtz4kyRuXUG7oQKY= +github.com/dghubble/sling v1.4.2 h1:vs1HIGBbSl2SEALyU+irpYFLZMfc49Fp+jYryFebQjM= +github.com/dghubble/sling v1.4.2/go.mod h1:o0arCOz0HwfqYQJLrRtqunaWOn4X6jxE/6ORKRpVTD4= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM= -github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0= +github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/enescakir/emoji v1.0.0 h1:W+HsNql8swfCQFtioDGDHCHri8nudlK1n5p2rHCJoog= -github.com/enescakir/emoji v1.0.0/go.mod h1:Bt1EKuLnKDTYpLALApstIkAjdDrS/8IAgTkKp+WKFD0= +github.com/expr-lang/expr v1.16.9 h1:WUAzmR0JNI9JCiF0/ewwHB1gmcGw5wW7nWt8gc6PpCI= +github.com/expr-lang/expr v1.16.9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/foxcpp/go-mockdns v1.0.0 h1:7jBqxd3WDWwi/6WhDvacvH1XsN3rOLXyHM1uhvIx6FI= github.com/foxcpp/go-mockdns v1.0.0/go.mod h1:lgRN6+KxQBawyIghpnl5CezHFGS9VLzvtVlwxvzXTQ4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= -github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= -github.com/go-co-op/gocron v1.17.0 h1:IixLXsti+Qo0wMvmn6Kmjp2csk2ykpkcL+EmHmST18w= -github.com/go-co-op/gocron v1.17.0/go.mod h1:IpDBSaJOVfFw7hXZuTag3SCSkqazXBBUkbQ1m1aesBs= +github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0= +github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -247,18 +245,14 @@ github.com/go-openapi/validate v0.19.12/go.mod h1:Rzou8hA/CBw8donlS6WNEUQupNvUZ0 github.com/go-openapi/validate v0.19.15/go.mod h1:tbn/fdOwYHgrhPBzidZfJC2MIVvs9GA7monOmWBbeCI= github.com/go-openapi/validate v0.20.0 h1:pzutNCCBZGZlE+u8HD3JZyWdc/TVbtVwlWUp8/vgUKk= github.com/go-openapi/validate v0.20.0/go.mod h1:b60iJT+xNNLfaQJUqLI7946tYiFEOuE9E4k54HpKcJ0= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= -github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.17.0 h1:SmVVlfAOtlZncTxRuinDPomC2DkXJ4E5T9gDA0AIH74= +github.com/go-playground/validator/v10 v10.17.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -300,15 +294,13 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE= -github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ= +github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= +github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= @@ -321,18 +313,20 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= -github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/winops v0.0.0-20230712152054-af9b550d0601 h1:XvlrmqZIuwxuRE88S9mkxX+FkV+YakqbiAC5Z4OzDnM= github.com/google/winops v0.0.0-20230712152054-af9b550d0601/go.mod h1:rT1mcjzuvcDDbRmUTsoH6kV0DG91AkFe9UCjASraK5I= github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e h1:XmA6L9IPRdUr28a+SK/oMchGgQy159wvzXA5tJ7l+40= @@ -374,8 +368,8 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= -github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= +github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -391,28 +385,30 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvW github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= -github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag= +github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= -github.com/jackc/pgtype v1.9.1 h1:MJc2s0MFS8C3ok1wQTdQxWuXQcB6+HwAm5x1CzW7mf0= -github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= -github.com/jackc/pgx/v4 v4.14.1 h1:71oo1KAGI6mXhLiTMn6iDFcp3e7+zon/capWjl2OEFU= -github.com/jackc/pgx/v4 v4.14.1/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M= +github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= +github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jarcoal/httpmock v1.1.0 h1:F47ChZj1Y2zFsCXxNkBPwNNKnAyOATcdQibk0qEdVCE= github.com/jarcoal/httpmock v1.1.0/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= +github.com/jedib0t/go-pretty/v6 v6.5.9 h1:ACteMBRrrmm1gMsXe9PSTOClQ63IXDUt03H5U+UV8OU= +github.com/jedib0t/go-pretty/v6 v6.5.9/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -423,7 +419,6 @@ github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqx github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jszwec/csvutil v1.5.1 h1:c3GFBhj6DFMUl4dMK3+B6rz2+LWWS/e9VJiVJ9t9kfQ= @@ -442,13 +437,15 @@ github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHU github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -457,11 +454,10 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= -github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.3.0 h1:jX8FDLfW4ThVXctBNZ+3cIWnCSnrACDV73r76dy0aQQ= +github.com/leodido/go-urn v1.3.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -471,8 +467,8 @@ github.com/lithammer/dedent v1.1.0 h1:VNzHMVCBNG1j0fh3OrsFRkVUwStdDArbgBWoPAffkt github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= -github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= -github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho= +github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -496,10 +492,10 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= -github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -531,7 +527,6 @@ github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= @@ -550,23 +545,23 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec= github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= -github.com/oschwald/geoip2-golang v1.4.0 h1:5RlrjCgRyIGDz/mBmPfnAF4h8k0IAcRv9PvrpOfz+Ug= -github.com/oschwald/geoip2-golang v1.4.0/go.mod h1:8QwxJvRImBH+Zl6Aa6MaIcs5YdlZSTKtzmPGzQqi9ng= -github.com/oschwald/maxminddb-golang v1.6.0/go.mod h1:DUJFucBg2cvqx42YmDa/+xHvb0elJtOm3o4aFQ/nb/w= -github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk= -github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis= +github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc= +github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y= +github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= +github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= -github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= -github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= +github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e h1:POJco99aNgosh92lGqmx7L1ei+kCymivB/419SD15PQ= github.com/petar-dambovaliev/aho-corasick v0.0.0-20230725210150-fb29fc3c913e/go.mod h1:EHPiTAKtiFmrMldLUNswFwfZ2eJIYBHktdaUTZxYWRw= github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ= github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -604,6 +599,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= @@ -640,8 +637,8 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -649,8 +646,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -664,11 +662,11 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs= -github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= -github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g= +github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -685,10 +683,8 @@ github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 h1:PM5hJF7HVfNWmCjM github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 h1:UFHFmFfixpmfRBcxuu+LA9l8MdURWVdVNUHxO5n1d2w= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26/go.mod h1:IGhd0qMDsUa9acVjsbsT7bu3ktadtGOHI79+idTew/M= github.com/vektah/gqlparser v1.1.2/go.mod h1:1ycwN7Ij5njmMkPPAOaRFY4rET2Enx7IkVv3vaXspKw= @@ -698,10 +694,12 @@ github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaU github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= github.com/vmihailenco/msgpack/v4 v4.3.12/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -github.com/wasilibs/go-re2 v1.3.0 h1:LFhBNzoStM3wMie6rN2slD1cuYH2CGiHpvNL3UtcsMw= -github.com/wasilibs/go-re2 v1.3.0/go.mod h1:AafrCXVvGRJJOImMajgJ2M7rVmWyisVK7sFshbxnVrg= +github.com/wasilibs/go-re2 v1.7.0 h1:bYhl8gn+a9h01dxwotNycxkiFPTiSgwUrIz8KZJ90Lc= +github.com/wasilibs/go-re2 v1.7.0/go.mod h1:sUsZMLflgl+LNivDE229omtmvjICmOseT9xOy199VDU= github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ= github.com/wasilibs/nottinygc v0.4.0/go.mod h1:oDcIotskuYNMpqMF23l7Z8uzD4TC0WXHK8jetlB3HIo= +github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 h1:OvLBa8SqJnZ6P+mjlzc2K7PM22rRUPE1x32G9DTPrC4= +github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52/go.mod h1:jMeV4Vpbi8osrE/pKUxRZkVaA0EX7NZN0A9/oRzgpgY= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= @@ -735,6 +733,9 @@ go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -743,8 +744,8 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= +golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -764,8 +765,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -773,8 +774,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= -golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -798,8 +799,8 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -809,10 +810,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -829,7 +828,6 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -841,18 +839,16 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -860,8 +856,8 @@ golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -874,8 +870,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -899,8 +895,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 h1:0wxTF6pSjIIhNt7mo9GvjDfzyCOiWhmICgtO/Ah948s= -golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -912,14 +908,14 @@ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= -google.golang.org/grpc v1.56.3 h1:8I4C0Yq1EjstUzUJzpcRVbuYA2mODtEmpWiQoN/b2nc= -google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -962,6 +958,7 @@ k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg= k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 h1:qY1Ad8PODbnymg2pRbkyMT/ylpTrCM8P2RJ0yroCyIk= k8s.io/utils v0.0.0-20230406110748-d93618cff8a2/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/make_chocolatey.ps1 b/make_chocolatey.ps1 index 67f85c33d89..cceed28402f 100644 --- a/make_chocolatey.ps1 +++ b/make_chocolatey.ps1 @@ -15,4 +15,6 @@ if ($version.Contains("-")) Set-Location .\windows\Chocolatey\crowdsec Copy-Item ..\..\..\crowdsec_$version.msi tools\crowdsec.msi -choco pack --version $version \ No newline at end of file +choco pack --version $version + +Copy-Item crowdsec.$version.nupkg ..\..\..\ \ No newline at end of file diff --git a/make_installer.ps1 b/make_installer.ps1 index a20ffaf55b5..c927452ff72 100644 --- a/make_installer.ps1 +++ b/make_installer.ps1 @@ -1,7 +1,7 @@ param ( $version ) -$env:Path += ";C:\Program Files (x86)\WiX Toolset v3.11\bin" +$env:Path += ";C:\Program Files (x86)\WiX Toolset v3.14\bin" if ($version.StartsWith("v")) { $version = $version.Substring(1) diff --git a/mk/check_go_version.ps1 b/mk/check_go_version.ps1 deleted file mode 100644 index 6060cb22751..00000000000 --- a/mk/check_go_version.ps1 +++ /dev/null @@ -1,19 +0,0 @@ -##This must be called with $(MINIMUM_SUPPORTED_GO_MAJOR_VERSION) $(MINIMUM_SUPPORTED_GO_MINOR_VERSION) in this order -$min_major=$args[0] -$min_minor=$args[1] -$goversion = (go env GOVERSION).replace("go","").split(".") -$goversion_major=$goversion[0] -$goversion_minor=$goversion[1] -$err_msg="Golang version $goversion_major.$goversion_minor is not supported, please use least $min_major.$min_minor" - -if ( $goversion_major -gt $min_major ) { - exit 0; -} -elseif ($goversion_major -lt $min_major) { - Write-Output $err_msg; - exit 1; -} -elseif ($goversion_minor -lt $min_minor) { - Write-Output $(GO_VERSION_VALIDATION_ERR_MSG); - exit 1; -} diff --git a/mk/goversion.mk b/mk/goversion.mk deleted file mode 100644 index 73e9a72e232..00000000000 --- a/mk/goversion.mk +++ /dev/null @@ -1,36 +0,0 @@ - -BUILD_GOVERSION = $(subst go,,$(shell $(GO) env GOVERSION)) - -go_major_minor = $(subst ., ,$(BUILD_GOVERSION)) -GO_MAJOR_VERSION = $(word 1, $(go_major_minor)) -GO_MINOR_VERSION = $(word 2, $(go_major_minor)) - -GO_VERSION_VALIDATION_ERR_MSG = Golang version ($(BUILD_GOVERSION)) is not supported, please use at least $(BUILD_REQUIRE_GO_MAJOR).$(BUILD_REQUIRE_GO_MINOR) - - -.PHONY: goversion -goversion: $(if $(findstring devel,$(shell $(GO) env GOVERSION)),goversion_devel,goversion_check) - - -.PHONY: goversion_devel -goversion_devel: - $(warning WARNING: You are using a development version of Golang ($(BUILD_GOVERSION)) which is not supported. For production environments, use a stable version (at least $(BUILD_REQUIRE_GO_MAJOR).$(BUILD_REQUIRE_GO_MINOR))) - $(info ) - - -.PHONY: goversion_check -goversion_check: -ifneq ($(OS), Windows_NT) - @if [ $(GO_MAJOR_VERSION) -gt $(BUILD_REQUIRE_GO_MAJOR) ]; then \ - exit 0; \ - elif [ $(GO_MAJOR_VERSION) -lt $(BUILD_REQUIRE_GO_MAJOR) ]; then \ - echo '$(GO_VERSION_VALIDATION_ERR_MSG)';\ - exit 1; \ - elif [ $(GO_MINOR_VERSION) -lt $(BUILD_REQUIRE_GO_MINOR) ] ; then \ - echo '$(GO_VERSION_VALIDATION_ERR_MSG)';\ - exit 1; \ - fi -else - # This needs Set-ExecutionPolicy -Scope CurrentUser Unrestricted - @$(CURDIR)/mk/check_go_version.ps1 $(BUILD_REQUIRE_GO_MAJOR) $(BUILD_REQUIRE_GO_MINOR) -endif diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 33602936369..4519ea7392b 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -1,14 +1,15 @@ package acquisition import ( + "context" "errors" "fmt" "io" "os" "strings" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -18,21 +19,9 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" - appsecacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec" - cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" - dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" - fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" - journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" - kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" - kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" - k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" - lokiacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" - s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" - syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" - wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -51,84 +40,110 @@ func (e *DataSourceUnavailableError) Unwrap() error { // The interface each datasource must implement type DataSource interface { - GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module - GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) - UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime - Configure([]byte, *log.Entry) error // Complete the YAML datasource configuration and perform runtime checks. - ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource - GetMode() string // Get the mode (TAIL, CAT or SERVER) - GetName() string // Get the name of the module - OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) - StreamingAcquisition(chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) - CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) - GetUuid() string // Get the unique identifier of the datasource + GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module + GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) + UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime + Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and perform runtime checks. + ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource + GetMode() string // Get the mode (TAIL, CAT or SERVER) + GetName() string // Get the name of the module + OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) + StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) + CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) + GetUuid() string // Get the unique identifier of the datasource Dump() interface{} } -var AcquisitionSources = map[string]func() DataSource{ - "file": func() DataSource { return &fileacquisition.FileSource{} }, - "journalctl": func() DataSource { return &journalctlacquisition.JournalCtlSource{} }, - "cloudwatch": func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }, - "syslog": func() DataSource { return &syslogacquisition.SyslogSource{} }, - "docker": func() DataSource { return &dockeracquisition.DockerSource{} }, - "kinesis": func() DataSource { return &kinesisacquisition.KinesisSource{} }, - "wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }, - "kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} }, - "k8s-audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }, - "loki": func() DataSource { return &lokiacquisition.LokiSource{} }, - "s3": func() DataSource { return &s3acquisition.S3Source{} }, - "appsec": func() DataSource { return &appsecacquisition.AppsecSource{} }, +var ( + // We declare everything here so we can tell if they are unsupported, or excluded from the build + AcquisitionSources = map[string]func() DataSource{} + transformRuntimes = map[string]*vm.Program{} +) + +func GetDataSourceIface(dataSourceType string) (DataSource, error) { + source, registered := AcquisitionSources[dataSourceType] + if registered { + return source(), nil + } + + built, known := component.Built["datasource_"+dataSourceType] + + if !known { + return nil, fmt.Errorf("unknown data source %s", dataSourceType) + } + + if built { + panic("datasource " + dataSourceType + " is built but not registered") + } + + return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", dataSourceType) } -var transformRuntimes = map[string]*vm.Program{} +// registerDataSource registers a datasource in the AcquisitionSources map. +// It must be called in the init() function of the datasource package, and the datasource name +// must be declared with a nil value in the map, to allow for conditional compilation. +func registerDataSource(dataSourceType string, dsGetter func() DataSource) { + component.Register("datasource_" + dataSourceType) -func GetDataSourceIface(dataSourceType string) DataSource { - source := AcquisitionSources[dataSourceType] - if source == nil { - return nil + AcquisitionSources[dataSourceType] = dsGetter +} + +// setupLogger creates a logger for the datasource to use at runtime. +func setupLogger(source, name string, level *log.Level) (*log.Entry, error) { + clog := log.New() + if err := types.ConfigureLogger(clog); err != nil { + return nil, fmt.Errorf("while configuring datasource logger: %w", err) + } + + if level != nil { + clog.SetLevel(*level) + } + + fields := log.Fields{ + "type": source, + } + + if name != "" { + fields["name"] = name } - return source() + + subLogger := clog.WithFields(fields) + + return subLogger, nil } // DataSourceConfigure creates and returns a DataSource object from a configuration, // if the configuration is not valid it returns an error. // If the datasource can't be run (eg. journalctl not available), it still returns an error which // can be checked for the appropriate action. -func DataSourceConfigure(commonConfig configuration.DataSourceCommonCfg) (*DataSource, error) { +func DataSourceConfigure(commonConfig configuration.DataSourceCommonCfg, metricsLevel int) (*DataSource, error) { // we dump it back to []byte, because we want to decode the yaml blob twice: // once to DataSourceCommonCfg, and then later to the dedicated type of the datasource yamlConfig, err := yaml.Marshal(commonConfig) if err != nil { - return nil, fmt.Errorf("unable to marshal back interface: %w", err) + return nil, fmt.Errorf("unable to serialize back interface: %w", err) } - if dataSrc := GetDataSourceIface(commonConfig.Source); dataSrc != nil { - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) - } - if commonConfig.LogLevel != nil { - clog.SetLevel(*commonConfig.LogLevel) - } - customLog := log.Fields{ - "type": commonConfig.Source, - } - if commonConfig.Name != "" { - customLog["name"] = commonConfig.Name - } - subLogger := clog.WithFields(customLog) - /* check eventual dependencies are satisfied (ie. journald will check journalctl availability) */ - if err := dataSrc.CanRun(); err != nil { - return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} - } - /* configure the actual datasource */ - if err := dataSrc.Configure(yamlConfig, subLogger); err != nil { - return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) - } - return &dataSrc, nil + dataSrc, err := GetDataSourceIface(commonConfig.Source) + if err != nil { + return nil, err + } + + subLogger, err := setupLogger(commonConfig.Source, commonConfig.Name, commonConfig.LogLevel) + if err != nil { + return nil, err + } + + /* check eventual dependencies are satisfied (ie. journald will check journalctl availability) */ + if err := dataSrc.CanRun(); err != nil { + return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} + } + /* configure the actual datasource */ + if err := dataSrc.Configure(yamlConfig, subLogger, metricsLevel); err != nil { + return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) } - return nil, fmt.Errorf("cannot find source %s", commonConfig.Source) + + return &dataSrc, nil } // detectBackwardCompatAcquis: try to magically detect the type for backward compat (type was not mandatory then) @@ -136,12 +151,15 @@ func detectBackwardCompatAcquis(sub configuration.DataSourceCommonCfg) string { if _, ok := sub.Config["filename"]; ok { return "file" } + if _, ok := sub.Config["filenames"]; ok { return "file" } + if _, ok := sub.Config["journalctl_filter"]; ok { return "journalctl" } + return "" } @@ -152,109 +170,160 @@ func LoadAcquisitionFromDSN(dsn string, labels map[string]string, transformExpr if len(frags) == 1 { return nil, fmt.Errorf("%s isn't valid dsn (no protocol)", dsn) } - dataSrc := GetDataSourceIface(frags[0]) - if dataSrc == nil { - return nil, fmt.Errorf("no acquisition for protocol %s://", frags[0]) + + dataSrc, err := GetDataSourceIface(frags[0]) + if err != nil { + return nil, fmt.Errorf("no acquisition for protocol %s:// - %w", frags[0], err) } - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) + + subLogger, err := setupLogger(dsn, "", nil) + if err != nil { + return nil, err } - subLogger := clog.WithFields(log.Fields{ - "type": dsn, - }) + uniqueId := uuid.NewString() + if transformExpr != "" { vm, err := expr.Compile(transformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s': %w", transformExpr, err) } + transformRuntimes[uniqueId] = vm } - err := dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) + + err = dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) if err != nil { return nil, fmt.Errorf("while configuration datasource for %s: %w", dsn, err) } + sources = append(sources, dataSrc) + return sources, nil } +func GetMetricsLevelFromPromCfg(prom *csconfig.PrometheusCfg) int { + if prom == nil { + return configuration.METRICS_FULL + } + + if !prom.Enabled { + return configuration.METRICS_NONE + } + + if prom.Level == configuration.CFG_METRICS_AGGREGATE { + return configuration.METRICS_AGGREGATE + } + + if prom.Level == configuration.CFG_METRICS_FULL { + return configuration.METRICS_FULL + } + + return configuration.METRICS_FULL +} + // LoadAcquisitionFromFile unmarshals the configuration item and checks its availability -func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg) ([]DataSource, error) { +func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig.PrometheusCfg) ([]DataSource, error) { var sources []DataSource + metrics_level := GetMetricsLevelFromPromCfg(prom) + for _, acquisFile := range config.AcquisitionFiles { log.Infof("loading acquisition file : %s", acquisFile) + yamlFile, err := os.Open(acquisFile) if err != nil { return nil, err } + dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) + idx := -1 + for { var sub configuration.DataSourceCommonCfg - err = dec.Decode(&sub) + idx += 1 + + err = dec.Decode(&sub) if err != nil { if !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to yaml decode %s: %w", acquisFile, err) } + log.Tracef("End of yaml file") + break } - //for backward compat ('type' was not mandatory, detect it) + // for backward compat ('type' was not mandatory, detect it) if guessType := detectBackwardCompatAcquis(sub); guessType != "" { sub.Source = guessType } - //it's an empty item, skip it + // it's an empty item, skip it if len(sub.Labels) == 0 { if sub.Source == "" { log.Debugf("skipping empty item in %s", acquisFile) continue } - return nil, fmt.Errorf("missing labels in %s (position: %d)", acquisFile, idx) + + if sub.Source != "docker" { + // docker is the only source that can be empty + return nil, fmt.Errorf("missing labels in %s (position: %d)", acquisFile, idx) + } } + if sub.Source == "" { return nil, fmt.Errorf("data source type is empty ('source') in %s (position: %d)", acquisFile, idx) } - if GetDataSourceIface(sub.Source) == nil { - return nil, fmt.Errorf("unknown data source %s in %s (position: %d)", sub.Source, acquisFile, idx) + + // pre-check that the source is valid + _, err := GetDataSourceIface(sub.Source) + if err != nil { + return nil, fmt.Errorf("in file %s (position: %d) - %w", acquisFile, idx, err) } + uniqueId := uuid.NewString() sub.UniqueId = uniqueId - src, err := DataSourceConfigure(sub) + + src, err := DataSourceConfigure(sub, metrics_level) if err != nil { var dserr *DataSourceUnavailableError if errors.As(err, &dserr) { log.Error(err) continue } + return nil, fmt.Errorf("while configuring datasource of type %s from %s (position: %d): %w", sub.Source, acquisFile, idx, err) } + if sub.TransformExpr != "" { vm, err := expr.Compile(sub.TransformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s' for datasource %s in %s (position: %d): %w", sub.TransformExpr, sub.Source, acquisFile, idx, err) } + transformRuntimes[uniqueId] = vm } + sources = append(sources, *src) } } + return sources, nil } func GetMetrics(sources []DataSource, aggregated bool) error { var metrics []prometheus.Collector - for i := 0; i < len(sources); i++ { + + for i := range sources { if aggregated { metrics = sources[i].GetMetrics() } else { metrics = sources[i].GetAggregMetrics() } + for _, metric := range metrics { if err := prometheus.Register(metric); err != nil { if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { @@ -264,12 +333,14 @@ func GetMetrics(sources []DataSource, aggregated bool) error { } } } + return nil } func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) { defer trace.CatchPanic("crowdsec/acquis") logger.Infof("transformer started") + for { select { case <-AcquisTomb.Dying(): @@ -277,15 +348,18 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo return case evt := <-transformChan: logger.Tracef("Received event %s", evt.Line.Raw) + out, err := expr.Run(transformRuntime, map[string]interface{}{"evt": &evt}) if err != nil { logger.Errorf("while running transform expression: %s, sending event as-is", err) output <- evt } + if out == nil { logger.Errorf("transform expression returned nil, sending event as-is") output <- evt } + switch v := out.(type) { case string: logger.Tracef("transform expression returned %s", v) @@ -293,18 +367,22 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo output <- evt case []interface{}: logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content + for _, line := range v { l, ok := line.(string) if !ok { logger.Errorf("transform expression returned []interface{}, but cannot assert an element to string") output <- evt + continue } + evt.Line.Raw = l output <- evt } case []string: logger.Tracef("transform expression returned %v", v) + for _, line := range v { evt.Line.Raw = line output <- evt @@ -317,49 +395,58 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo } } -func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +func StartAcquisition(ctx context.Context, sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { // Don't wait if we have no sources, as it will hang forever if len(sources) == 0 { return nil } - for i := 0; i < len(sources); i++ { - subsrc := sources[i] //ensure its a copy + for i := range sources { + subsrc := sources[i] // ensure its a copy log.Debugf("starting one source %d/%d ->> %T", i, len(sources), subsrc) AcquisTomb.Go(func() error { defer trace.CatchPanic("crowdsec/acquis") + var err error outChan := output + log.Debugf("datasource %s UUID: %s", subsrc.GetName(), subsrc.GetUuid()) + if transformRuntime, ok := transformRuntimes[subsrc.GetUuid()]; ok { log.Infof("transform expression found for datasource %s", subsrc.GetName()) + transformChan := make(chan types.Event) outChan = transformChan transformLogger := log.WithFields(log.Fields{ "component": "transform", "datasource": subsrc.GetName(), }) + AcquisTomb.Go(func() error { transform(outChan, output, AcquisTomb, transformRuntime, transformLogger) return nil }) } + if subsrc.GetMode() == configuration.TAIL_MODE { - err = subsrc.StreamingAcquisition(outChan, AcquisTomb) + err = subsrc.StreamingAcquisition(ctx, outChan, AcquisTomb) } else { err = subsrc.OneShotAcquisition(outChan, AcquisTomb) } + if err != nil { - //if one of the acqusition returns an error, we kill the others to properly shutdown + // if one of the acqusition returns an error, we kill the others to properly shutdown AcquisTomb.Kill(err) } + return nil }) } /*return only when acquisition is over (cat) or never (tail)*/ err := AcquisTomb.Wait() + return err } diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index 44b3878e1d0..e82b3df54c2 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -1,6 +1,8 @@ package acquisition import ( + "context" + "errors" "fmt" "strings" "testing" @@ -35,7 +37,7 @@ func (f *MockSource) UnmarshalConfig(cfg []byte) error { return nil } -func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockSource) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if err := f.UnmarshalConfig(cfg); err != nil { return err @@ -50,21 +52,23 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error { } if f.Toto == "" { - return fmt.Errorf("expect non-empty toto") + return errors.New("expect non-empty toto") } return nil } -func (f *MockSource) GetMode() string { return f.Mode } -func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) CanRun() error { return nil } -func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSource) Dump() interface{} { return f } -func (f *MockSource) GetName() string { return "mock" } +func (f *MockSource) GetMode() string { return f.Mode } +func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSource) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSource) CanRun() error { return nil } +func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSource) Dump() interface{} { return f } +func (f *MockSource) GetName() string { return "mock" } func (f *MockSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockSource) GetUuid() string { return "" } @@ -73,18 +77,13 @@ type MockSourceCantRun struct { MockSource } -func (f *MockSourceCantRun) CanRun() error { return fmt.Errorf("can't run bro") } +func (f *MockSourceCantRun) CanRun() error { return errors.New("can't run bro") } func (f *MockSourceCantRun) GetName() string { return "mock_cant_run" } // appendMockSource is only used to add mock source for tests func appendMockSource() { - if GetDataSourceIface("mock") == nil { - AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } - } - - if GetDataSourceIface("mock_cant_run") == nil { - AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } - } + AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } + AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } } func TestDataSourceConfigure(t *testing.T) { @@ -149,7 +148,7 @@ labels: log_level: debug source: tutu `, - ExpectedError: "cannot find source tutu", + ExpectedError: "unknown data source tutu", }, { TestName: "mismatch_config", @@ -178,12 +177,12 @@ wowo: ajsajasjas } for _, tc := range tests { - tc := tc t.Run(tc.TestName, func(t *testing.T) { common := configuration.DataSourceCommonCfg{} yaml.Unmarshal([]byte(tc.String), &common) - ds, err := DataSourceConfigure(common) + ds, err := DataSourceConfigure(common, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } @@ -270,7 +269,7 @@ func TestLoadAcquisitionFromFile(t *testing.T) { Config: csconfig.CrowdsecServiceCfg{ AcquisitionFiles: []string{"test_files/bad_source.yaml"}, }, - ExpectedError: "unknown data source does_not_exist in test_files/bad_source.yaml", + ExpectedError: "in file test_files/bad_source.yaml (position: 0) - unknown data source does_not_exist", }, { TestName: "invalid_filetype_config", @@ -281,10 +280,10 @@ func TestLoadAcquisitionFromFile(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.TestName, func(t *testing.T) { - dss, err := LoadAcquisitionFromFile(&tc.Config) + dss, err := LoadAcquisitionFromFile(&tc.Config, nil) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } @@ -305,7 +304,7 @@ type MockCat struct { logger *log.Entry } -func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockCat) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if f.Mode == "" { f.Mode = configuration.CAT_MODE @@ -322,7 +321,7 @@ func (f *MockCat) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockCat) GetName() string { return "mock_cat" } func (f *MockCat) GetMode() string { return "cat" } func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { - for i := 0; i < 10; i++ { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt @@ -330,15 +329,16 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro return nil } -func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { - return fmt.Errorf("can't run in tail") + +func (f *MockCat) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return errors.New("can't run in tail") } func (f *MockCat) CanRun() error { return nil } func (f *MockCat) GetMetrics() []prometheus.Collector { return nil } func (f *MockCat) GetAggregMetrics() []prometheus.Collector { return nil } func (f *MockCat) Dump() interface{} { return f } func (f *MockCat) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockCat) GetUuid() string { return "" } @@ -349,7 +349,7 @@ type MockTail struct { logger *log.Entry } -func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error { +func (f *MockTail) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { f.logger = logger if f.Mode == "" { f.Mode = configuration.TAIL_MODE @@ -366,14 +366,16 @@ func (f *MockTail) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockTail) GetName() string { return "mock_tail" } func (f *MockTail) GetMode() string { return "tail" } func (f *MockTail) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { - return fmt.Errorf("can't run in cat mode") + return errors.New("can't run in cat mode") } -func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { - for i := 0; i < 10; i++ { + +func (f *MockTail) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } + <-t.Dying() return nil @@ -383,13 +385,14 @@ func (f *MockTail) GetMetrics() []prometheus.Collector { return nil } func (f *MockTail) GetAggregMetrics() []prometheus.Collector { return nil } func (f *MockTail) Dump() interface{} { return f } func (f *MockTail) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("not supported") + return errors.New("not supported") } func (f *MockTail) GetUuid() string { return "" } -//func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +// func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { func TestStartAcquisitionCat(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockCat{}, } @@ -397,7 +400,7 @@ func TestStartAcquisitionCat(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -417,6 +420,7 @@ READLOOP: } func TestStartAcquisitionTail(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTail{}, } @@ -424,7 +428,7 @@ func TestStartAcquisitionTail(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -451,18 +455,20 @@ type MockTailError struct { MockTail } -func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { - for i := 0; i < 10; i++ { +func (f *MockTailError) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } - t.Kill(fmt.Errorf("got error (tomb)")) - return fmt.Errorf("got error") + t.Kill(errors.New("got error (tomb)")) + + return errors.New("got error") } func TestStartAcquisitionTailError(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTailError{}, } @@ -470,7 +476,7 @@ func TestStartAcquisitionTailError(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { t.Errorf("expected error, got '%s'", err) } }() @@ -486,7 +492,7 @@ READLOOP: } } assert.Equal(t, 10, count) - //acquisTomb.Kill(nil) + // acquisTomb.Kill(nil) time.Sleep(1 * time.Second) cstest.RequireErrorContains(t, acquisTomb.Err(), "got error (tomb)") } @@ -497,20 +503,24 @@ type MockSourceByDSN struct { logger *log.Entry //nolint: unused } -func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } -func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry) error { return nil } -func (f *MockSourceByDSN) GetMode() string { return f.Mode } -func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) CanRun() error { return nil } -func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) Dump() interface{} { return f } -func (f *MockSourceByDSN) GetName() string { return "mockdsn" } +func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } +func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { + return nil +} +func (f *MockSourceByDSN) GetMode() string { return f.Mode } +func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSourceByDSN) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSourceByDSN) CanRun() error { return nil } +func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) Dump() interface{} { return f } +func (f *MockSourceByDSN) GetName() string { return "mockdsn" } func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { dsn = strings.TrimPrefix(dsn, "mockdsn://") if dsn != "test_expect" { - return fmt.Errorf("unexpected value") + return errors.New("unexpected value") } return nil @@ -541,12 +551,9 @@ func TestConfigureByDSN(t *testing.T) { }, } - if GetDataSourceIface("mockdsn") == nil { - AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } - } + AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } for _, tc := range tests { - tc := tc t.Run(tc.dsn, func(t *testing.T) { srcs, err := LoadAcquisitionFromDSN(tc.dsn, map[string]string{"type": "test_label"}, "") cstest.RequireErrorContains(t, err, tc.ExpectedError) diff --git a/pkg/acquisition/appsec.go b/pkg/acquisition/appsec.go new file mode 100644 index 00000000000..81616d3d2b8 --- /dev/null +++ b/pkg/acquisition/appsec.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_appsec + +package acquisition + +import ( + appsecacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("appsec", func() DataSource { return &appsecacquisition.AppsecSource{} }) +} diff --git a/pkg/acquisition/cloudwatch.go b/pkg/acquisition/cloudwatch.go new file mode 100644 index 00000000000..e6b3d3e3e53 --- /dev/null +++ b/pkg/acquisition/cloudwatch.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_cloudwatch + +package acquisition + +import ( + cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("cloudwatch", func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }) +} diff --git a/pkg/acquisition/configuration/configuration.go b/pkg/acquisition/configuration/configuration.go index 5ec1a4ac4c3..3e27da1b9e6 100644 --- a/pkg/acquisition/configuration/configuration.go +++ b/pkg/acquisition/configuration/configuration.go @@ -19,3 +19,14 @@ type DataSourceCommonCfg struct { var TAIL_MODE = "tail" var CAT_MODE = "cat" var SERVER_MODE = "server" // No difference with tail, just a bit more verbose + +const ( + METRICS_NONE = iota + METRICS_AGGREGATE + METRICS_FULL +) + +const ( + CFG_METRICS_AGGREGATE = "aggregated" + CFG_METRICS_FULL = "full" +) diff --git a/pkg/acquisition/docker.go b/pkg/acquisition/docker.go new file mode 100644 index 00000000000..3bf792a039a --- /dev/null +++ b/pkg/acquisition/docker.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_docker + +package acquisition + +import ( + dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("docker", func() DataSource { return &dockeracquisition.DockerSource{} }) +} diff --git a/pkg/acquisition/file.go b/pkg/acquisition/file.go new file mode 100644 index 00000000000..1ff2e4a3c0e --- /dev/null +++ b/pkg/acquisition/file.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_file + +package acquisition + +import ( + fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("file", func() DataSource { return &fileacquisition.FileSource{} }) +} diff --git a/pkg/acquisition/journalctl.go b/pkg/acquisition/journalctl.go new file mode 100644 index 00000000000..691f961ae77 --- /dev/null +++ b/pkg/acquisition/journalctl.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_journalctl + +package acquisition + +import ( + journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("journalctl", func() DataSource { return &journalctlacquisition.JournalCtlSource{} }) +} diff --git a/pkg/acquisition/k8s.go b/pkg/acquisition/k8s.go new file mode 100644 index 00000000000..cb9446be285 --- /dev/null +++ b/pkg/acquisition/k8s.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_k8saudit + +package acquisition + +import ( + k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("k8s-audit", func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }) +} diff --git a/pkg/acquisition/kafka.go b/pkg/acquisition/kafka.go new file mode 100644 index 00000000000..7d315d87feb --- /dev/null +++ b/pkg/acquisition/kafka.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kafka + +package acquisition + +import ( + kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kafka", func() DataSource { return &kafkaacquisition.KafkaSource{} }) +} diff --git a/pkg/acquisition/kinesis.go b/pkg/acquisition/kinesis.go new file mode 100644 index 00000000000..b41372e7fb9 --- /dev/null +++ b/pkg/acquisition/kinesis.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kinesis + +package acquisition + +import ( + kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kinesis", func() DataSource { return &kinesisacquisition.KinesisSource{} }) +} diff --git a/pkg/acquisition/loki.go b/pkg/acquisition/loki.go new file mode 100644 index 00000000000..1eed6686591 --- /dev/null +++ b/pkg/acquisition/loki.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_loki + +package acquisition + +import ( + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("loki", func() DataSource { return &loki.LokiSource{} }) +} diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 030724fc3e9..5161b631c33 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -3,23 +3,26 @@ package appsecacquisition import ( "context" "encoding/json" + "errors" "fmt" + "net" "net/http" + "os" "sync" "time" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - - "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" - "github.com/crowdsecurity/crowdsec/pkg/appsec" - "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/crowdsecurity/go-cs-lib/trace" "github.com/google/uuid" - "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/types" ) const ( @@ -27,13 +30,12 @@ const ( OutOfBand = "outofband" ) -var ( - DefaultAuthCacheDuration = (1 * time.Minute) -) +var DefaultAuthCacheDuration = (1 * time.Minute) // configuration structure of the acquis for the application security engine type AppsecSourceConfig struct { ListenAddr string `yaml:"listen_addr"` + ListenSocket string `yaml:"listen_socket"` CertFilePath string `yaml:"cert_file"` KeyFilePath string `yaml:"key_file"` Path string `yaml:"path"` @@ -46,6 +48,7 @@ type AppsecSourceConfig struct { // runtime structure of AppsecSourceConfig type AppsecSource struct { + metricsLevel int config AppsecSourceConfig logger *log.Entry mux *http.ServeMux @@ -56,7 +59,7 @@ type AppsecSource struct { AppsecConfigs map[string]appsec.AppsecConfig lapiURL string AuthCache AuthCache - AppsecRunners []AppsecRunner //one for each go-routine + AppsecRunners []AppsecRunner // one for each go-routine } // Struct to handle cache of authentication @@ -91,13 +94,12 @@ type BodyResponse struct { } func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { - err := yaml.UnmarshalStrict(yamlConfig, &w.config) if err != nil { - return errors.Wrap(err, "Cannot parse appsec configuration") + return fmt.Errorf("cannot parse appsec configuration: %w", err) } - if w.config.ListenAddr == "" { + if w.config.ListenAddr == "" && w.config.ListenSocket == "" { w.config.ListenAddr = "127.0.0.1:7422" } @@ -119,11 +121,16 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { } if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" { - return fmt.Errorf("appsec_config or appsec_config_path must be set") + return errors.New("appsec_config or appsec_config_path must be set") } if w.config.Name == "" { - w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + if w.config.ListenSocket != "" && w.config.ListenAddr == "" { + w.config.Name = w.config.ListenSocket + } + if w.config.ListenSocket == "" { + w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + } } csConfig := csconfig.GetConfig() @@ -141,13 +148,13 @@ func (w *AppsecSource) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{AppsecReqCounter, AppsecBlockCounter, AppsecRuleHits, AppsecOutbandParsingHistogram, AppsecInbandParsingHistogram, AppsecGlobalParsingHistogram} } -func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { err := w.UnmarshalConfig(yamlConfig) if err != nil { - return errors.Wrap(err, "unable to parse appsec configuration") + return fmt.Errorf("unable to parse appsec configuration: %w", err) } w.logger = logger - + w.metricsLevel = MetricsLevel w.logger.Tracef("Appsec configuration: %+v", w.config) if w.config.AuthCacheDuration == nil { @@ -165,64 +172,61 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry) error { w.InChan = make(chan appsec.ParsedRequest) appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} - //let's load the associated appsec_config: + // let's load the associated appsec_config: if w.config.AppsecConfigPath != "" { err := appsecCfg.LoadByPath(w.config.AppsecConfigPath) if err != nil { - return fmt.Errorf("unable to load appsec_config : %s", err) + return fmt.Errorf("unable to load appsec_config: %w", err) } } else if w.config.AppsecConfig != "" { err := appsecCfg.Load(w.config.AppsecConfig) if err != nil { - return fmt.Errorf("unable to load appsec_config : %s", err) + return fmt.Errorf("unable to load appsec_config: %w", err) } } else { - return fmt.Errorf("no appsec_config provided") + return errors.New("no appsec_config provided") } w.AppsecRuntime, err = appsecCfg.Build() if err != nil { - return fmt.Errorf("unable to build appsec_config : %s", err) + return fmt.Errorf("unable to build appsec_config: %w", err) } err = w.AppsecRuntime.ProcessOnLoadRules() - if err != nil { - return fmt.Errorf("unable to process on load rules : %s", err) + return fmt.Errorf("unable to process on load rules: %w", err) } w.AppsecRunners = make([]AppsecRunner, w.config.Routines) - for nbRoutine := 0; nbRoutine < w.config.Routines; nbRoutine++ { + for nbRoutine := range w.config.Routines { appsecRunnerUUID := uuid.New().String() - //we copy AppsecRutime for each runner + // we copy AppsecRutime for each runner wrt := *w.AppsecRuntime wrt.Logger = w.logger.Dup().WithField("runner_uuid", appsecRunnerUUID) runner := AppsecRunner{ - inChan: w.InChan, - UUID: appsecRunnerUUID, - logger: w.logger.WithFields(log.Fields{ - "runner_uuid": appsecRunnerUUID, - }), + inChan: w.InChan, + UUID: appsecRunnerUUID, + logger: w.logger.WithField("runner_uuid", appsecRunnerUUID), AppsecRuntime: &wrt, Labels: w.config.Labels, } err := runner.Init(appsecCfg.GetDataDir()) if err != nil { - return fmt.Errorf("unable to initialize runner : %s", err) + return fmt.Errorf("unable to initialize runner: %w", err) } w.AppsecRunners[nbRoutine] = runner } w.logger.Infof("Created %d appsec runners", len(w.AppsecRunners)) - //We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec + // We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec w.mux.HandleFunc(w.config.Path, w.appsecHandler) return nil } func (w *AppsecSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { - return fmt.Errorf("AppSec datasource does not support command line acquisition") + return errors.New("AppSec datasource does not support command line acquisition") } func (w *AppsecSource) GetMode() string { @@ -234,41 +238,61 @@ func (w *AppsecSource) GetName() string { } func (w *AppsecSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("AppSec datasource does not support command line acquisition") + return errors.New("AppSec datasource does not support command line acquisition") } -func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { w.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live") w.logger.Infof("%d appsec runner to start", len(w.AppsecRunners)) for _, runner := range w.AppsecRunners { - runner := runner runner.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live/runner") return runner.Run(t) }) } - - w.logger.Infof("Starting Appsec server on %s%s", w.config.ListenAddr, w.config.Path) t.Go(func() error { - var err error - if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { - err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) - } else { - err = w.server.ListenAndServe() + if w.config.ListenSocket != "" { + w.logger.Infof("creating unix socket %s", w.config.ListenSocket) + _ = os.RemoveAll(w.config.ListenSocket) + listener, err := net.Listen("unix", w.config.ListenSocket) + if err != nil { + return fmt.Errorf("appsec server failed: %w", err) + } + defer listener.Close() + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ServeTLS(listener, w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("appsec server failed: %w", err) + } } - - if err != nil && err != http.ErrServerClosed { - return errors.Wrap(err, "Appsec server failed") + return nil + }) + t.Go(func() error { + var err error + if w.config.ListenAddr != "" { + w.logger.Infof("creating TCP server on %s", w.config.ListenAddr) + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.ListenAndServe() + } + + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("appsec server failed: %w", err) + } } return nil }) <-t.Dying() - w.logger.Infof("Stopping Appsec server on %s%s", w.config.ListenAddr, w.config.Path) - //xx let's clean up the appsec runners :) + w.logger.Info("Shutting down Appsec server") + // xx let's clean up the appsec runners :) appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) w.server.Shutdown(context.TODO()) return nil @@ -308,7 +332,6 @@ func (w *AppsecSource) IsAuth(apiKey string) bool { defer resp.Body.Close() return resp.StatusCode == http.StatusOK - } // should this be in the runner ? @@ -354,24 +377,25 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { w.InChan <- parsedRequest + /* + response is a copy of w.AppSecRuntime.Response that is safe to use. + As OutOfBand might still be running, the original one can be modified + */ response := <-parsedRequest.ResponseChannel - statusCode := http.StatusOK if response.InBandInterrupt { - statusCode = http.StatusForbidden AppsecBlockCounter.With(prometheus.Labels{"source": parsedRequest.RemoteAddrNormalized, "appsec_engine": parsedRequest.AppsecEngine}).Inc() } - appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger) + statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger) logger.Debugf("Response: %+v", appsecResponse) rw.WriteHeader(statusCode) body, err := json.Marshal(appsecResponse) if err != nil { - logger.Errorf("unable to marshal response: %s", err) + logger.Errorf("unable to serialize response: %s", err) rw.WriteHeader(http.StatusInternalServerError) } else { rw.Write(body) } - } diff --git a/pkg/acquisition/modules/appsec/appsec_hooks_test.go b/pkg/acquisition/modules/appsec/appsec_hooks_test.go new file mode 100644 index 00000000000..c549d2ef1d1 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_hooks_test.go @@ -0,0 +1,894 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecOnMatchHooks(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "no rule : check return code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 403, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change return code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(413)"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 413, responses[0].UserHTTPResponseCode) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to a non standard one (log)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('log')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, "log", responses[0].Action) + require.Equal(t, 403, responses[0].BouncerHTTPResponseCode) + require.Equal(t, 403, responses[0].UserHTTPResponseCode) + }, + }, + { + name: "on_match: change action to another standard one (allow)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to another standard one (ban)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, responses, 1) + //note: SetAction normalizes deny, ban and block to ban + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to another standard one (captcha)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, responses, 1) + //note: SetAction normalizes deny, ban and block to ban + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + }, + }, + { + name: "on_match: change action to a non standard one", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('foobar')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + }, + }, + { + name: "on_match: cancel alert", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true && LogInfo('XX -> %s', evt.Appsec.MatchedRules.GetName())", Apply: []string{"CancelAlert()"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.LOG, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: cancel event", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.APPSEC, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: on_success break", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}, OnSuccess: "break"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.APPSEC, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.BanRemediation, responses[0].Action) + }, + }, + { + name: "on_match: on_success continue", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}, OnSuccess: "continue"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.APPSEC, events[0].Type) + require.Len(t, responses, 1) + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecPreEvalHooks(t *testing.T) { + + tests := []appsecRuleTest{ + { + name: "Basic pre_eval hook to disable inband rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "1 == 1", Apply: []string{"RemoveInBandRuleByName('rule1')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "Basic pre_eval fails to disable rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "1 ==2", Apply: []string{"RemoveInBandRuleByName('rule1')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + + }, + }, + { + name: "pre_eval : disable inband by tag", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"RemoveInBandRuleByTag('crowdsec-rulez')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : disable inband by ID", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"RemoveInBandRuleByID(1516470898)"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : disable inband by name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"RemoveInBandRuleByName('rulez')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : outofband default behavior", + expected_load_ok: true, + outofband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 1) + require.Equal(t, types.LOG, events[0].Type) + require.True(t, events[0].Appsec.HasOutBandMatches) + require.False(t, events[0].Appsec.HasInBandMatches) + require.Len(t, events[0].Appsec.MatchedRules, 1) + require.Equal(t, "rulez", events[0].Appsec.MatchedRules[0]["msg"]) + //maybe surprising, but response won't mention OOB event, as it's sent as soon as the inband phase is over. + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "pre_eval : set remediation by tag", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"SetRemediationByTag('crowdsec-rulez', 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + }, + }, + { + name: "pre_eval : set remediation by name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"SetRemediationByName('rulez', 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + }, + }, + { + name: "pre_eval : set remediation by ID", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Apply: []string{"SetRemediationByID(1516470898, 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + require.Equal(t, "foobar", appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "pre_eval : on_success continue", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar')"}, OnSuccess: "continue"}, + {Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar2')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Len(t, responses, 1) + require.Equal(t, "foobar2", responses[0].Action) + }, + }, + { + name: "pre_eval : on_success break", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rulez", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar')"}, OnSuccess: "break"}, + {Filter: "1==1", Apply: []string{"SetRemediationByName('rulez', 'foobar2')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Len(t, responses, 1) + require.Equal(t, "foobar", responses[0].Action) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecRemediationConfigHooks(t *testing.T) { + + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "SetRemediation", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetRemediation('captcha')"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "SetRemediation", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{{Apply: []string{"SetReturnCode(418)"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (: + + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} +func TestOnMatchRemediationHooks(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "set remediation to allow with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "set remediation to captcha + custom user code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + { + name: "on_match: on_success break", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}, OnSuccess: "break"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + { + name: "on_match: on_success continue", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: appsec.AllowRemediation, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}, OnSuccess: "continue"}, + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + spew.Dump(responses) + spew.Dump(appsecResponse) + + log.Errorf("http status : %d", statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + require.Equal(t, http.StatusForbidden, statusCode) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_lnx_test.go b/pkg/acquisition/modules/appsec/appsec_lnx_test.go new file mode 100644 index 00000000000..61dfc536f5e --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_lnx_test.go @@ -0,0 +1,74 @@ +//go:build !windows + +package appsecacquisition + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecRuleTransformsOthers(t *testing.T) { + log.SetLevel(log.TraceLevel) + + tests := []appsecRuleTest{ + { + name: "normalizepath", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "b/c"}, + Transform: []string{"normalizepath"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=a/../b/c", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "normalizepath #2", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "b/c/"}, + Transform: []string{"normalizepath"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=a/../b/c/////././././", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_remediation_test.go b/pkg/acquisition/modules/appsec/appsec_remediation_test.go new file mode 100644 index 00000000000..06016b6251f --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_remediation_test.go @@ -0,0 +1,319 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecDefaultPassRemediation(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "Basic non-matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassAction: pass", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + DefaultPassAction: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassAction: captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + DefaultPassAction: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) //@tko: body is captcha, but as it's 200, captcha won't be showed to user + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassHTTPCode: 200", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + UserPassedHTTPCode: 200, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "DefaultPassHTTPCode: 200", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Args: url.Values{"foo": []string{"tutu"}}, + }, + UserPassedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecDefaultRemediation(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to ban (default)", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "ban", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "custom remediation + HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + DefaultRemediation: "foobar", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, "foobar", responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "foobar", appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_rules_test.go b/pkg/acquisition/modules/appsec/appsec_rules_test.go new file mode 100644 index 00000000000..909f16357ed --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_rules_test.go @@ -0,0 +1,859 @@ +package appsecacquisition + +import ( + "net/http" + "net/url" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestAppsecRuleMatches(t *testing.T) { + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic non-matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Len(t, responses, 1) + require.False(t, responses[0].InBandInterrupt) + require.False(t, responses[0].OutOfBandInterrupt) + }, + }, + { + name: "default remediation to allow", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "allow", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.AllowRemediation, responses[0].Action) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus) + }, + }, + { + name: "default remediation to captcha", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + DefaultRemediation: "captcha", + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.CaptchaRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action) + require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus) + }, + }, + { + name: "no default remediation / custom user HTTP code", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"toto"}}, + }, + UserBlockedHTTPCode: 418, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Equal(t, appsec.BanRemediation, responses[0].Action) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, appsec.BanRemediation, appsecResponse.Action) + require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus) + }, + }, + { + name: "no match but try to set remediation to captcha with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set user HTTP code with on_match hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + on_match: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "no match but try to set remediation with pre_eval hook", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule42", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + pre_eval: []appsec.Hook{ + {Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}}, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Args: url.Values{"foo": []string{"bla"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Empty(t, events) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, appsec.AllowRemediation, appsecResponse.Action) + }, + }, + { + name: "Basic matching in cookies", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"foo=toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in all cookies", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES"}, + Match: appsec_rule.Match{Type: "regex", Value: "^tutu"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"foo=toto; bar=tutu"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in cookie name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"COOKIES_NAMES"}, + Match: appsec_rule.Match{Type: "regex", Value: "^tutu"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Cookie": []string{"bar=tutu; tututata=toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching in multipart file name", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"FILES"}, + Match: appsec_rule.Match{Type: "regex", Value: "\\.php$"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, + Body: []byte(` +--boundary +Content-Disposition: form-data; name="foo"; filename="bar.php" +Content-Type: application/octet-stream + +toto +--boundary--`), + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecRuleTransforms(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "Basic matching rule", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/toto"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "lowercase", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/toto"}, + Transform: []string{"lowercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/TOTO", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "uppercase", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/TOTO"}, + Transform: []string{"uppercase"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "b64decode", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + Transform: []string{"b64decode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=dG90bw", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "b64decode with extra padding", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + Transform: []string{"b64decode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=dG90bw===", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "length", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "gte", Value: "3"}, + Transform: []string{"length"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=toto", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "urldecode", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "BB/A"}, + Transform: []string{"urldecode"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=%42%42%2F%41", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "trim", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: appsec_rule.Match{Type: "equals", Value: "BB/A"}, + Transform: []string{"urldecode", "trim"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/?foo=%20%20%42%42%2F%41%20%20", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} + +func TestAppsecRuleZones(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "rule: ARGS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?something=toto&foobar=smth", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: ARGS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?something=toto&foobar=smth", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: BODY_ARGS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"BODY_ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"BODY_ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Body: []byte("smth=toto&foobar=other"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: BODY_ARGS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"BODY_ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"BODY_ARGS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Body: []byte("smth=toto&foobar=other"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: HEADERS", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"HEADERS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"HEADERS"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Headers: http.Header{"foobar": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: HEADERS_NAMES", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"HEADERS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"HEADERS_NAMES"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Headers: http.Header{"foobar": []string{"toto"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule2", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: METHOD", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"METHOD"}, + Match: appsec_rule.Match{Type: "equals", Value: "GET"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: PROTOCOL", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"PROTOCOL"}, + Match: appsec_rule.Match{Type: "contains", Value: "3.1"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Proto: "HTTP/3.1", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: URI", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI"}, + Match: appsec_rule.Match{Type: "equals", Value: "/foobar"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: URI_FULL", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"URI_FULL"}, + Match: appsec_rule.Match{Type: "equals", Value: "/foobar?a=b"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/foobar?a=b", + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + { + name: "rule: RAW_BODY", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"RAW_BODY"}, + Match: appsec_rule.Match{Type: "equals", Value: "foobar=42421"}, + }, + }, + input_request: appsec.ParsedRequest{ + RemoteAddr: "1.2.3.4", + Method: "GET", + URI: "/", + Body: []byte("foobar=42421"), + Headers: http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + require.Equal(t, types.LOG, events[1].Type) + require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go index a9d74aa8f63..de34b62d704 100644 --- a/pkg/acquisition/modules/appsec/appsec_runner.go +++ b/pkg/acquisition/modules/appsec/appsec_runner.go @@ -6,15 +6,17 @@ import ( "slices" "time" - "github.com/crowdsecurity/coraza/v3" - corazatypes "github.com/crowdsecurity/coraza/v3/types" - "github.com/crowdsecurity/crowdsec/pkg/appsec" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + "github.com/crowdsecurity/coraza/v3" + corazatypes "github.com/crowdsecurity/coraza/v3/types" + + // load body processors via init() _ "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec/bodyprocessors" + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/types" ) // that's the runtime structure of the Application security engine as seen from the acquis @@ -165,7 +167,7 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap return nil } - if request.Body != nil && len(request.Body) > 0 { + if len(request.Body) > 0 { in, _, err = request.Tx.WriteRequestBody(request.Body) if err != nil { r.logger.Errorf("unable to write request body : %s", err) @@ -177,7 +179,6 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap } in, err = request.Tx.ProcessRequestBody() - if err != nil { r.logger.Errorf("unable to process request body : %s", err) return err @@ -226,7 +227,8 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { if in := request.Tx.Interruption(); in != nil { r.logger.Debugf("inband rules matched : %d", in.RuleID) r.AppsecRuntime.Response.InBandInterrupt = true - r.AppsecRuntime.Response.HTTPResponseCode = r.AppsecRuntime.Config.BlockedHTTPCode + r.AppsecRuntime.Response.BouncerHTTPResponseCode = r.AppsecRuntime.Config.BouncerBlockedHTTPCode + r.AppsecRuntime.Response.UserHTTPResponseCode = r.AppsecRuntime.Config.UserBlockedHTTPCode r.AppsecRuntime.Response.Action = r.AppsecRuntime.DefaultRemediation if _, ok := r.AppsecRuntime.RemediationById[in.RuleID]; ok { @@ -252,7 +254,9 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { r.logger.Errorf("unable to generate appsec event : %s", err) return } - r.outChan <- *appsecOvlfw + if appsecOvlfw != nil { + r.outChan <- *appsecOvlfw + } } // Should the in band match trigger an event ? diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go index 2a58580137d..d2079b43726 100644 --- a/pkg/acquisition/modules/appsec/appsec_test.go +++ b/pkg/acquisition/modules/appsec/appsec_test.go @@ -1,651 +1,34 @@ package appsecacquisition import ( - "net/url" "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/appsec" - "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/davecgh/go-spew/spew" "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" -) -/* -Missing tests (wip): - - GenerateResponse - - evt.Appsec and it's subobjects and methods -*/ + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + "github.com/crowdsecurity/crowdsec/pkg/types" +) type appsecRuleTest struct { - name string - expected_load_ok bool - inband_rules []appsec_rule.CustomRule - outofband_rules []appsec_rule.CustomRule - on_load []appsec.Hook - pre_eval []appsec.Hook - post_eval []appsec.Hook - on_match []appsec.Hook - input_request appsec.ParsedRequest - output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse) -} - -func TestAppsecOnMatchHooks(t *testing.T) { - tests := []appsecRuleTest{ - { - name: "no rule : check return code", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - require.Equal(t, types.LOG, events[1].Type) - require.Len(t, responses, 1) - require.Equal(t, 403, responses[0].HTTPResponseCode) - require.Equal(t, "ban", responses[0].Action) - - }, - }, - { - name: "on_match: change return code", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(413)"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - require.Equal(t, types.LOG, events[1].Type) - require.Len(t, responses, 1) - require.Equal(t, 413, responses[0].HTTPResponseCode) - require.Equal(t, "ban", responses[0].Action) - }, - }, - { - name: "on_match: change action to another standard one (log)", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('log')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - require.Equal(t, types.LOG, events[1].Type) - require.Len(t, responses, 1) - require.Equal(t, "log", responses[0].Action) - }, - }, - { - name: "on_match: change action to another standard one (allow)", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - require.Equal(t, types.LOG, events[1].Type) - require.Len(t, responses, 1) - require.Equal(t, "allow", responses[0].Action) - }, - }, - { - name: "on_match: change action to another standard one (deny/ban/block)", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('deny')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, responses, 1) - //note: SetAction normalizes deny, ban and block to ban - require.Equal(t, "ban", responses[0].Action) - }, - }, - { - name: "on_match: change action to another standard one (captcha)", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, responses, 1) - //note: SetAction normalizes deny, ban and block to ban - require.Equal(t, "captcha", responses[0].Action) - }, - }, - { - name: "on_match: change action to a non standard one", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"SetRemediation('foobar')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - require.Equal(t, types.LOG, events[1].Type) - require.Len(t, responses, 1) - require.Equal(t, "foobar", responses[0].Action) - }, - }, - { - name: "on_match: cancel alert", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule42", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true && LogInfo('XX -> %s', evt.Appsec.MatchedRules.GetName())", Apply: []string{"CancelAlert()"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 1) - require.Equal(t, types.LOG, events[0].Type) - require.Len(t, responses, 1) - require.Equal(t, "ban", responses[0].Action) - }, - }, - { - name: "on_match: cancel event", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule42", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - on_match: []appsec.Hook{ - {Filter: "IsInBand == true", Apply: []string{"CancelEvent()"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 1) - require.Equal(t, types.APPSEC, events[0].Type) - require.Len(t, responses, 1) - require.Equal(t, "ban", responses[0].Action) - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - loadAppSecEngine(test, t) - }) - } -} - -func TestAppsecPreEvalHooks(t *testing.T) { - /* - [x] basic working hook - [x] basic failing hook - [ ] test the "OnSuccess" feature - [ ] test multiple competing hooks - [ ] test the variety of helpers - */ - tests := []appsecRuleTest{ - { - name: "Basic on_load hook to disable inband rule", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Filter: "1 == 1", Apply: []string{"RemoveInBandRuleByName('rule1')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Empty(t, events) - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - { - name: "Basic on_load fails to disable rule", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Filter: "1 ==2", Apply: []string{"RemoveInBandRuleByName('rule1')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - - require.Equal(t, types.LOG, events[1].Type) - require.True(t, events[1].Appsec.HasInBandMatches) - require.Len(t, events[1].Appsec.MatchedRules, 1) - require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) - - require.Len(t, responses, 1) - require.True(t, responses[0].InBandInterrupt) - - }, - }, - { - name: "on_load : disable inband by tag", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"RemoveInBandRuleByTag('crowdsec-rulez')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Empty(t, events) - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - { - name: "on_load : disable inband by ID", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"RemoveInBandRuleByID(1516470898)"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Empty(t, events) - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - { - name: "on_load : disable inband by name", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"RemoveInBandRuleByName('rulez')"}}, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Empty(t, events) - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - { - name: "on_load : outofband default behavior", - expected_load_ok: true, - outofband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 1) - require.Equal(t, types.LOG, events[0].Type) - require.True(t, events[0].Appsec.HasOutBandMatches) - require.False(t, events[0].Appsec.HasInBandMatches) - require.Len(t, events[0].Appsec.MatchedRules, 1) - require.Equal(t, "rulez", events[0].Appsec.MatchedRules[0]["msg"]) - //maybe surprising, but response won't mention OOB event, as it's sent as soon as the inband phase is over. - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - { - name: "on_load : set remediation by tag", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"SetRemediationByTag('crowdsec-rulez', 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Len(t, responses, 1) - require.Equal(t, "foobar", responses[0].Action) - }, - }, - { - name: "on_load : set remediation by name", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"SetRemediationByName('rulez', 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Len(t, responses, 1) - require.Equal(t, "foobar", responses[0].Action) - }, - }, - { - name: "on_load : set remediation by ID", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rulez", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - pre_eval: []appsec.Hook{ - {Apply: []string{"SetRemediationByID(1516470898, 'foobar')"}}, //rule ID is generated at runtime. If you change rule, it will break the test (: - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Len(t, responses, 1) - require.Equal(t, "foobar", responses[0].Action) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - loadAppSecEngine(test, t) - }) - } -} -func TestAppsecRuleMatches(t *testing.T) { - - /* - [x] basic matching rule - [x] basic non-matching rule - [ ] test the transformation - [ ] ? - */ - tests := []appsecRuleTest{ - { - name: "Basic matching rule", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"toto"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Len(t, events, 2) - require.Equal(t, types.APPSEC, events[0].Type) - - require.Equal(t, types.LOG, events[1].Type) - require.True(t, events[1].Appsec.HasInBandMatches) - require.Len(t, events[1].Appsec.MatchedRules, 1) - require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) - - require.Len(t, responses, 1) - require.True(t, responses[0].InBandInterrupt) - }, - }, - { - name: "Basic non-matching rule", - expected_load_ok: true, - inband_rules: []appsec_rule.CustomRule{ - { - Name: "rule1", - Zones: []string{"ARGS"}, - Variables: []string{"foo"}, - Match: appsec_rule.Match{Type: "regex", Value: "^toto"}, - Transform: []string{"lowercase"}, - }, - }, - input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", - Method: "GET", - URI: "/urllll", - Args: url.Values{"foo": []string{"tutu"}}, - }, - output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) { - require.Empty(t, events) - require.Len(t, responses, 1) - require.False(t, responses[0].InBandInterrupt) - require.False(t, responses[0].OutOfBandInterrupt) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - loadAppSecEngine(test, t) - }) - } + name string + expected_load_ok bool + inband_rules []appsec_rule.CustomRule + outofband_rules []appsec_rule.CustomRule + on_load []appsec.Hook + pre_eval []appsec.Hook + post_eval []appsec.Hook + on_match []appsec.Hook + BouncerBlockedHTTPCode int + UserBlockedHTTPCode int + UserPassedHTTPCode int + DefaultRemediation string + DefaultPassAction string + input_request appsec.ParsedRequest + output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) } func loadAppSecEngine(test appsecRuleTest, t *testing.T) { @@ -659,7 +42,7 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { InChan := make(chan appsec.ParsedRequest) OutChan := make(chan types.Event) - logger := log.WithFields(log.Fields{"test": test.name}) + logger := log.WithField("test", test.name) //build rules for ridx, rule := range test.inband_rules { @@ -678,7 +61,16 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { outofbandRules = append(outofbandRules, strRule) } - appsecCfg := appsec.AppsecConfig{Logger: logger, OnLoad: test.on_load, PreEval: test.pre_eval, PostEval: test.post_eval, OnMatch: test.on_match} + appsecCfg := appsec.AppsecConfig{Logger: logger, + OnLoad: test.on_load, + PreEval: test.pre_eval, + PostEval: test.post_eval, + OnMatch: test.on_match, + BouncerBlockedHTTPCode: test.BouncerBlockedHTTPCode, + UserBlockedHTTPCode: test.UserBlockedHTTPCode, + UserPassedHTTPCode: test.UserPassedHTTPCode, + DefaultRemediation: test.DefaultRemediation, + DefaultPassAction: test.DefaultPassAction} AppsecRuntime, err := appsecCfg.Build() if err != nil { t.Fatalf("unable to build appsec runtime : %s", err) @@ -724,8 +116,9 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { runner.handleRequest(&input) time.Sleep(50 * time.Millisecond) + + http_status, appsecResponse := AppsecRuntime.GenerateResponse(OutputResponses[0], logger) log.Infof("events : %s", spew.Sdump(OutputEvents)) log.Infof("responses : %s", spew.Sdump(OutputResponses)) - test.output_asserts(OutputEvents, OutputResponses) - + test.output_asserts(OutputEvents, OutputResponses, appsecResponse, http_status) } diff --git a/pkg/acquisition/modules/appsec/appsec_win_test.go b/pkg/acquisition/modules/appsec/appsec_win_test.go new file mode 100644 index 00000000000..a6b8f3a0340 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_win_test.go @@ -0,0 +1,45 @@ +//go:build windows + +package appsecacquisition + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestAppsecRuleTransformsWindows(t *testing.T) { + + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + // { + // name: "normalizepath", + // expected_load_ok: true, + // inband_rules: []appsec_rule.CustomRule{ + // { + // Name: "rule1", + // Zones: []string{"ARGS"}, + // Variables: []string{"foo"}, + // Match: appsec_rule.Match{Type: "equals", Value: "b/c"}, + // Transform: []string{"normalizepath"}, + // }, + // }, + // input_request: appsec.ParsedRequest{ + // RemoteAddr: "1.2.3.4", + // Method: "GET", + // URI: "/?foo=a/../b/c", + // }, + // output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + // require.Len(t, events, 2) + // require.Equal(t, types.APPSEC, events[0].Type) + // require.Equal(t, types.LOG, events[1].Type) + // require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + // }, + // }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/rx_operator.go b/pkg/acquisition/modules/appsec/rx_operator.go index 43aaf9e94be..4b16296fd40 100644 --- a/pkg/acquisition/modules/appsec/rx_operator.go +++ b/pkg/acquisition/modules/appsec/rx_operator.go @@ -5,10 +5,11 @@ import ( "strconv" "unicode/utf8" - "github.com/crowdsecurity/coraza/v3/experimental/plugins" - "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" "github.com/wasilibs/go-re2" "github.com/wasilibs/go-re2/experimental" + + "github.com/crowdsecurity/coraza/v3/experimental/plugins" + "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" ) type rx struct { @@ -50,9 +51,9 @@ func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool { tx.CaptureField(i, c) } return true - } else { - return o.re.MatchString(value) } + + return o.re.MatchString(value) } // RegisterRX registers the rx operator using a WASI implementation instead of Go. diff --git a/pkg/acquisition/modules/appsec/utils.go b/pkg/acquisition/modules/appsec/utils.go index 7600617965a..4fb1a979d14 100644 --- a/pkg/acquisition/modules/appsec/utils.go +++ b/pkg/acquisition/modules/appsec/utils.go @@ -1,70 +1,182 @@ package appsecacquisition import ( - "encoding/json" "fmt" + "net" + "slices" + "strconv" "time" + "github.com/oschwald/geoip2-golang" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/coraza/v3/collection" "github.com/crowdsecurity/coraza/v3/types/variables" + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" ) +var appsecMetaKeys = []string{ + "id", + "name", + "method", + "uri", + "matched_zones", + "msg", +} + +func appendMeta(meta models.Meta, key string, value string) models.Meta { + if value == "" { + return meta + } + + meta = append(meta, &models.MetaItems0{ + Key: key, + Value: value, + }) + + return meta +} + func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { - //if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI + // if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI if !inEvt.Appsec.HasInBandMatches { return nil, nil } + evt := types.Event{} evt.Type = types.APPSEC evt.Process = true + sourceIP := inEvt.Parsed["source_ip"] source := models.Source{ - Value: ptr.Of(inEvt.Parsed["source_ip"]), - IP: inEvt.Parsed["source_ip"], + Value: &sourceIP, + IP: sourceIP, Scope: ptr.Of(types.Ip), } + asndata, err := exprhelpers.GeoIPASNEnrich(sourceIP) + + if err != nil { + log.Errorf("Unable to enrich ip '%s' for ASN: %s", sourceIP, err) + } else if asndata != nil { + record := asndata.(*geoip2.ASN) + source.AsName = record.AutonomousSystemOrganization + source.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) + } + + cityData, err := exprhelpers.GeoIPEnrich(sourceIP) + if err != nil { + log.Errorf("Unable to enrich ip '%s' for geo data: %s", sourceIP, err) + } else if cityData != nil { + record := cityData.(*geoip2.City) + source.Cn = record.Country.IsoCode + source.Latitude = float32(record.Location.Latitude) + source.Longitude = float32(record.Location.Longitude) + } + + rangeData, err := exprhelpers.GeoIPRangeEnrich(sourceIP) + if err != nil { + log.Errorf("Unable to enrich ip '%s' for range: %s", sourceIP, err) + } else if rangeData != nil { + record := rangeData.(*net.IPNet) + source.Range = record.String() + } + evt.Overflow.Sources = make(map[string]models.Source) - evt.Overflow.Sources["ip"] = source + evt.Overflow.Sources[sourceIP] = source alert := models.Alert{} alert.Capacity = ptr.Of(int32(1)) - alert.Events = make([]*models.Event, 0) - alert.Meta = make(models.Meta, 0) - for _, key := range []string{"target_uri", "method"} { + alert.Events = make([]*models.Event, len(evt.Appsec.GetRuleIDs())) - valueByte, err := json.Marshal([]string{inEvt.Parsed[key]}) - if err != nil { - log.Debugf("unable to serialize key %s", key) + now := ptr.Of(time.Now().UTC().Format(time.RFC3339)) + + tmpAppsecContext := make(map[string][]string) + + for _, matched_rule := range inEvt.Appsec.MatchedRules { + evtRule := models.Event{} + + evtRule.Timestamp = now + + evtRule.Meta = make(models.Meta, 0) + + for _, key := range appsecMetaKeys { + if tmpAppsecContext[key] == nil { + tmpAppsecContext[key] = make([]string, 0) + } + + switch value := matched_rule[key].(type) { + case string: + evtRule.Meta = appendMeta(evtRule.Meta, key, value) + + if value != "" && !slices.Contains(tmpAppsecContext[key], value) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], value) + } + case int: + val := strconv.Itoa(value) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + case []string: + for _, v := range value { + evtRule.Meta = appendMeta(evtRule.Meta, key, v) + + if v != "" && !slices.Contains(tmpAppsecContext[key], v) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], v) + } + } + case []int: + for _, v := range value { + val := strconv.Itoa(v) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + } + default: + val := fmt.Sprintf("%v", value) + evtRule.Meta = appendMeta(evtRule.Meta, key, val) + + if val != "" && !slices.Contains(tmpAppsecContext[key], val) { + tmpAppsecContext[key] = append(tmpAppsecContext[key], val) + } + } + } + + alert.Events = append(alert.Events, &evtRule) + } + + metas := make([]*models.MetaItems0, 0) + + for key, values := range tmpAppsecContext { + if len(values) == 0 { continue } + valueStr, err := alertcontext.TruncateContext(values, alertcontext.MaxContextValueLen) + if err != nil { + log.Warning(err.Error()) + } + meta := models.MetaItems0{ Key: key, - Value: string(valueByte), - } - alert.Meta = append(alert.Meta, &meta) - } - matchedZones := inEvt.Appsec.GetMatchedZones() - if matchedZones != nil { - valueByte, err := json.Marshal(matchedZones) - if err != nil { - log.Debugf("unable to serialize key matched_zones") - } else { - meta := models.MetaItems0{ - Key: "matched_zones", - Value: string(valueByte), - } - alert.Meta = append(alert.Meta, &meta) + Value: valueStr, } + metas = append(metas, &meta) } - alert.EventsCount = ptr.Of(int32(1)) + alert.Meta = metas + + alert.EventsCount = ptr.Of(int32(len(alert.Events))) alert.Leakspeed = ptr.Of("") alert.Scenario = ptr.Of(inEvt.Appsec.MatchedRules.GetName()) alert.ScenarioHash = ptr.Of(inEvt.Appsec.MatchedRules.GetHash()) @@ -78,15 +190,16 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) evt.Overflow.APIAlerts = []models.Alert{alert} evt.Overflow.Alert = &alert + return &evt, nil } func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types.Event, error) { evt := types.Event{} - //we might want to change this based on in-band vs out-of-band ? + // we might want to change this based on in-band vs out-of-band ? evt.Type = types.LOG evt.ExpectMode = types.LIVE - //def needs fixing + // def needs fixing evt.Stage = "s00-raw" evt.Parsed = map[string]string{ "source_ip": r.ClientIP, @@ -96,19 +209,19 @@ func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types. "req_uuid": r.Tx.ID(), "source": "crowdsec-appsec", "remediation_cmpt_ip": r.RemoteAddrNormalized, - //TBD: - //http_status - //user_agent + // TBD: + // http_status + // user_agent } evt.Line = types.Line{ Time: time.Now(), - //should we add some info like listen addr/port/path ? + // should we add some info like listen addr/port/path ? Labels: labels, Process: true, Module: "appsec", Src: "appsec", - Raw: "dummy-appsec-data", //we discard empty Line.Raw items :) + Raw: "dummy-appsec-data", // we discard empty Line.Raw items :) } evt.Appsec = types.AppsecEvent{} @@ -140,29 +253,29 @@ func LogAppsecEvent(evt *types.Event, logger *log.Entry) { "target_uri": req, }).Debugf("%s triggered non-blocking rules on %s (%d rules) [%v]", evt.Parsed["source_ip"], req, len(evt.Appsec.MatchedRules), evt.Appsec.GetRuleIDs()) } - } func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedRequest) error { - if evt == nil { - //an error was already emitted, let's not spam the logs + // an error was already emitted, let's not spam the logs return nil } if !req.Tx.IsInterrupted() { - //if the phase didn't generate an interruption, we don't have anything to add to the event + // if the phase didn't generate an interruption, we don't have anything to add to the event return nil } - //if one interruption was generated, event is good for processing :) + // if one interruption was generated, event is good for processing :) evt.Process = true if evt.Meta == nil { evt.Meta = map[string]string{} } + if evt.Parsed == nil { evt.Parsed = map[string]string{} } + if req.IsInBand { evt.Meta["appsec_interrupted"] = "true" evt.Meta["appsec_action"] = req.Tx.Interruption().Action @@ -183,9 +296,11 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR if variable.Key() != "" { key += "." + variable.Key() } + if variable.Value() == "" { continue } + for _, collectionToKeep := range r.AppsecRuntime.CompiledVariablesTracking { match := collectionToKeep.MatchString(key) if match { @@ -196,11 +311,12 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR } } } + return true }) for _, rule := range req.Tx.MatchedRules() { - if rule.Message() == "" || rule.DisruptiveAction() == "pass" || rule.DisruptiveAction() == "allow" { + if rule.Message() == "" { r.logger.Tracef("discarding rule %d (action: %s)", rule.Rule().ID(), rule.DisruptiveAction()) continue } @@ -218,11 +334,12 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR ruleNameProm := fmt.Sprintf("%d", rule.Rule().ID()) if details, ok := appsec.AppsecRulesDetails[rule.Rule().ID()]; ok { - //Only set them for custom rules, not for rules written in seclang + // Only set them for custom rules, not for rules written in seclang name = details.Name version = details.Version hash = details.Hash ruleNameProm = details.Name + r.logger.Debugf("custom rule for event, setting name: %s, version: %s, hash: %s", name, version, hash) } else { name = fmt.Sprintf("native_rule:%d", rule.Rule().ID()) @@ -231,18 +348,21 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR AppsecRuleHits.With(prometheus.Labels{"rule_name": ruleNameProm, "type": kind, "source": req.RemoteAddrNormalized, "appsec_engine": req.AppsecEngine}).Inc() matchedZones := make([]string, 0) + for _, matchData := range rule.MatchedDatas() { zone := matchData.Variable().Name() + varName := matchData.Key() if varName != "" { zone += "." + varName } + matchedZones = append(matchedZones, zone) } corazaRule := map[string]interface{}{ "id": rule.Rule().ID(), - "uri": evt.Parsed["uri"], + "uri": evt.Parsed["target_uri"], "rule_type": kind, "method": evt.Parsed["method"], "disruptive": rule.Disruptive(), @@ -263,5 +383,4 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR } return nil - } diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index 89887bef0b8..e4b6c95d77f 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -2,6 +2,7 @@ package cloudwatchacquisition import ( "context" + "errors" "fmt" "net/url" "os" @@ -43,7 +44,8 @@ var linesRead = prometheus.NewCounterVec( // CloudwatchSource is the runtime instance keeping track of N streams within 1 cloudwatch group type CloudwatchSource struct { - Config CloudwatchSourceConfiguration + metricsLevel int + Config CloudwatchSourceConfiguration /*runtime stuff*/ logger *log.Entry t *tomb.Tomb @@ -55,16 +57,16 @@ type CloudwatchSource struct { // CloudwatchSourceConfiguration allows user to define one or more streams to monitor within a cloudwatch log group type CloudwatchSourceConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` - GroupName string `yaml:"group_name"` //the group name to be monitored - StreamRegexp *string `yaml:"stream_regexp,omitempty"` //allow to filter specific streams + GroupName string `yaml:"group_name"` // the group name to be monitored + StreamRegexp *string `yaml:"stream_regexp,omitempty"` // allow to filter specific streams StreamName *string `yaml:"stream_name,omitempty"` StartTime, EndTime *time.Time `yaml:"-"` - DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` //batch size for DescribeLogStreamsPagesWithContext + DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` // batch size for DescribeLogStreamsPagesWithContext GetLogEventsPagesLimit *int64 `yaml:"getlogeventspages_limit,omitempty"` - PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` //frequency at which we poll for new streams within the log group - MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` //monitor only streams that have been updated within $duration - PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` //frequency at which we poll each stream - StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` //stop monitoring streams that haven't been updated within $duration, might be reopened later tho + PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` // frequency at which we poll for new streams within the log group + MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` // monitor only streams that have been updated within $duration + PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` // frequency at which we poll each stream + StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` // stop monitoring streams that haven't been updated within $duration, might be reopened later tho AwsApiCallTimeout *time.Duration `yaml:"aws_api_timeout,omitempty"` AwsProfile *string `yaml:"aws_profile,omitempty"` PrependCloudwatchTimestamp *bool `yaml:"prepend_cloudwatch_timestamp,omitempty"` @@ -84,7 +86,7 @@ type LogStreamTailConfig struct { logger *log.Entry ExpectMode int t tomb.Tomb - StartTime, EndTime time.Time //only used for CatMode + StartTime, EndTime time.Time // only used for CatMode } var ( @@ -109,8 +111,8 @@ func (cw *CloudwatchSource) UnmarshalConfig(yamlConfig []byte) error { return fmt.Errorf("cannot parse CloudwatchSource configuration: %w", err) } - if len(cw.Config.GroupName) == 0 { - return fmt.Errorf("group_name is mandatory for CloudwatchSource") + if cw.Config.GroupName == "" { + return errors.New("group_name is mandatory for CloudwatchSource") } if cw.Config.Mode == "" { @@ -152,12 +154,14 @@ func (cw *CloudwatchSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { err := cw.UnmarshalConfig(yamlConfig) if err != nil { return err } + cw.metricsLevel = MetricsLevel + cw.logger = logger.WithField("group", cw.Config.GroupName) cw.logger.Debugf("Starting configuration for Cloudwatch group %s", cw.Config.GroupName) @@ -172,42 +176,49 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry) erro if *cw.Config.MaxStreamAge > *cw.Config.StreamReadTimeout { cw.logger.Warningf("max_stream_age > stream_read_timeout, stream might keep being opened/closed") } + cw.logger.Tracef("aws_config_dir set to %s", *cw.Config.AwsConfigDir) if *cw.Config.AwsConfigDir != "" { _, err := os.Stat(*cw.Config.AwsConfigDir) if err != nil { cw.logger.Errorf("can't read aws_config_dir '%s' got err %s", *cw.Config.AwsConfigDir, err) - return fmt.Errorf("can't read aws_config_dir %s got err %s ", *cw.Config.AwsConfigDir, err) + return fmt.Errorf("can't read aws_config_dir %s got err %w ", *cw.Config.AwsConfigDir, err) } + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - //as aws sdk relies on $HOME, let's allow the user to override it :) + // as aws sdk relies on $HOME, let's allow the user to override it :) os.Setenv("AWS_CONFIG_FILE", fmt.Sprintf("%s/config", *cw.Config.AwsConfigDir)) os.Setenv("AWS_SHARED_CREDENTIALS_FILE", fmt.Sprintf("%s/credentials", *cw.Config.AwsConfigDir)) } else { if cw.Config.AwsRegion == nil { cw.logger.Errorf("aws_region is not specified, specify it or aws_config_dir") - return fmt.Errorf("aws_region is not specified, specify it or aws_config_dir") + return errors.New("aws_region is not specified, specify it or aws_config_dir") } + os.Setenv("AWS_REGION", *cw.Config.AwsRegion) } if err := cw.newClient(); err != nil { return err } + cw.streamIndexes = make(map[string]string) targetStream := "*" + if cw.Config.StreamRegexp != nil { if _, err := regexp.Compile(*cw.Config.StreamRegexp); err != nil { return fmt.Errorf("while compiling regexp '%s': %w", *cw.Config.StreamRegexp, err) } + targetStream = *cw.Config.StreamRegexp } else if cw.Config.StreamName != nil { targetStream = *cw.Config.StreamName } cw.logger.Infof("Adding cloudwatch group '%s' (stream:%s) to datasources", cw.Config.GroupName, targetStream) + return nil } @@ -226,26 +237,31 @@ func (cw *CloudwatchSource) newClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { cw.logger.Debugf("[testing] overloading endpoint with %s", v) cw.cwClient = cloudwatchlogs.New(sess, aws.NewConfig().WithEndpoint(v)) } else { cw.cwClient = cloudwatchlogs.New(sess) } + if cw.cwClient == nil { - return fmt.Errorf("failed to create cloudwatch client") + return errors.New("failed to create cloudwatch client") } + return nil } -func (cw *CloudwatchSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (cw *CloudwatchSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { cw.t = t monitChan := make(chan LogStreamTailConfig) + t.Go(func() error { return cw.LogStreamManager(monitChan, out) }) + return cw.WatchLogGroupForStreams(monitChan) } @@ -276,6 +292,7 @@ func (cw *CloudwatchSource) Dump() interface{} { func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig) error { cw.logger.Debugf("Starting to watch group (interval:%s)", cw.Config.PollNewStreamInterval) ticker := time.NewTicker(*cw.Config.PollNewStreamInterval) + var startFrom *string for { @@ -286,11 +303,12 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig case <-ticker.C: hasMoreStreams := true startFrom = nil + for hasMoreStreams { cw.logger.Tracef("doing the call to DescribeLogStreamsPagesWithContext") ctx := context.Background() - //there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime + // there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime err := cw.cwClient.DescribeLogStreamsPagesWithContext( ctx, &cloudwatchlogs.DescribeLogStreamsInput{ @@ -302,13 +320,14 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig }, func(page *cloudwatchlogs.DescribeLogStreamsOutput, lastPage bool) bool { cw.logger.Tracef("in helper of DescribeLogStreamsPagesWithContext") + for _, event := range page.LogStreams { startFrom = page.NextToken - //we check if the stream has been written to recently enough to be monitored + // we check if the stream has been written to recently enough to be monitored if event.LastIngestionTime != nil { - //aws uses millisecond since the epoch + // aws uses millisecond since the epoch oldest := time.Now().UTC().Add(-*cw.Config.MaxStreamAge) - //TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. + // TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. LastIngestionTime := time.Unix(0, *event.LastIngestionTime*int64(time.Millisecond)) if LastIngestionTime.Before(oldest) { cw.logger.Tracef("stop iteration, %s reached oldest age, stop (%s < %s)", *event.LogStreamName, LastIngestionTime, time.Now().UTC().Add(-*cw.Config.MaxStreamAge)) @@ -316,7 +335,7 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig return false } cw.logger.Tracef("stream %s is elligible for monitoring", *event.LogStreamName) - //the stream has been updated recently, check if we should monitor it + // the stream has been updated recently, check if we should monitor it var expectMode int if !cw.Config.UseTimeMachine { expectMode = types.LIVE @@ -354,7 +373,6 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig // LogStreamManager receives the potential streams to monitor, and starts a go routine when needed func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outChan chan types.Event) error { - cw.logger.Debugf("starting to monitor streams for %s", cw.Config.GroupName) pollDeadStreamInterval := time.NewTicker(def_PollDeadStreamInterval) @@ -381,11 +399,13 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if newStream.GroupName == stream.GroupName && newStream.StreamName == stream.StreamName { - //stream exists, but is dead, remove it from list + // stream exists, but is dead, remove it from list if !stream.t.Alive() { cw.logger.Debugf("stream %s already exists, but is dead", newStream.StreamName) cw.monitoredStreams = append(cw.monitoredStreams[:idx], cw.monitoredStreams[idx+1:]...) - openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Dec() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Dec() + } break } shouldCreate = false @@ -393,11 +413,13 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } } - //let's start watching this stream + // let's start watching this stream if shouldCreate { - openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() + } newStream.t = tomb.Tomb{} - newStream.logger = cw.logger.WithFields(log.Fields{"stream": newStream.StreamName}) + newStream.logger = cw.logger.WithField("stream", newStream.StreamName) cw.logger.Debugf("starting tail of stream %s", newStream.StreamName) newStream.t.Go(func() error { return cw.TailLogStream(&newStream, outChan) @@ -409,7 +431,9 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if !cw.monitoredStreams[idx].t.Alive() { cw.logger.Debugf("remove dead stream %s", stream.StreamName) - openedStreams.With(prometheus.Labels{"group": cw.monitoredStreams[idx].GroupName}).Dec() + if cw.metricsLevel != configuration.METRICS_NONE { + openedStreams.With(prometheus.Labels{"group": cw.monitoredStreams[idx].GroupName}).Dec() + } } else { newMonitoredStreams = append(newMonitoredStreams, stream) } @@ -437,7 +461,7 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan var startFrom *string lastReadMessage := time.Now().UTC() ticker := time.NewTicker(cfg.PollStreamInterval) - //resume at existing index if we already had + // resume at existing index if we already had streamIndexMutex.Lock() v := cw.streamIndexes[cfg.GroupName+"+"+cfg.StreamName] streamIndexMutex.Unlock() @@ -485,7 +509,9 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan cfg.logger.Warningf("cwLogToEvent error, discarded event : %s", err) } else { cfg.logger.Debugf("pushing message : %s", evt.Line.Raw) - linesRead.With(prometheus.Labels{"group": cfg.GroupName, "stream": cfg.StreamName}).Inc() + if cw.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"group": cfg.GroupName, "stream": cfg.StreamName}).Inc() + } outChan <- evt } } @@ -506,7 +532,7 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan } case <-cfg.t.Dying(): cfg.logger.Infof("logstream tail stopping") - return fmt.Errorf("killed") + return errors.New("killed") } } } @@ -517,11 +543,11 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, dsn = strings.TrimPrefix(dsn, cw.GetName()+"://") args := strings.Split(dsn, "?") if len(args) != 2 { - return fmt.Errorf("query is mandatory (at least start_date and end_date or backlog)") + return errors.New("query is mandatory (at least start_date and end_date or backlog)") } frags := strings.Split(args[0], ":") if len(frags) != 2 { - return fmt.Errorf("cloudwatch path must contain group and stream : /my/group/name:stream/name") + return errors.New("cloudwatch path must contain group and stream : /my/group/name:stream/name") } cw.Config.GroupName = frags[0] cw.Config.StreamName = &frags[1] @@ -537,7 +563,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, switch k { case "log_level": if len(v) != 1 { - return fmt.Errorf("expected zero or one value for 'log_level'") + return errors.New("expected zero or one value for 'log_level'") } lvl, err := log.ParseLevel(v[0]) if err != nil { @@ -547,32 +573,32 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, case "profile": if len(v) != 1 { - return fmt.Errorf("expected zero or one value for 'profile'") + return errors.New("expected zero or one value for 'profile'") } awsprof := v[0] cw.Config.AwsProfile = &awsprof cw.logger.Debugf("profile set to '%s'", *cw.Config.AwsProfile) case "start_date": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'start_date'") + return errors.New("expected zero or one argument for 'start_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, startDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.StartTime = &startDate case "end_date": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'end_date'") + return errors.New("expected zero or one argument for 'end_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, endDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.EndTime = &endDate case "backlog": if len(v) != 1 { - return fmt.Errorf("expected zero or one argument for 'backlog'") + return errors.New("expected zero or one argument for 'backlog'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported duration, err := time.ParseDuration(v[0]) if err != nil { return fmt.Errorf("unable to parse '%s' as duration: %w", v[0], err) @@ -595,10 +621,10 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, } if cw.Config.StreamName == nil || cw.Config.GroupName == "" { - return fmt.Errorf("missing stream or group name") + return errors.New("missing stream or group name") } if cw.Config.StartTime == nil || cw.Config.EndTime == nil { - return fmt.Errorf("start_date and end_date or backlog are mandatory in one-shot mode") + return errors.New("start_date and end_date or backlog are mandatory in one-shot mode") } cw.Config.Mode = configuration.CAT_MODE @@ -608,7 +634,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, } func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - //StreamName string, Start time.Time, End time.Time + // StreamName string, Start time.Time, End time.Time config := LogStreamTailConfig{ GroupName: cw.Config.GroupName, StreamName: *cw.Config.StreamName, @@ -627,7 +653,7 @@ func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tom func (cw *CloudwatchSource) CatLogStream(cfg *LogStreamTailConfig, outChan chan types.Event) error { var startFrom *string - var head = true + head := true /*convert the times*/ startTime := cfg.StartTime.UTC().Unix() * 1000 endTime := cfg.EndTime.UTC().Unix() * 1000 @@ -689,7 +715,7 @@ func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) l := types.Line{} evt := types.Event{} if log.Message == nil { - return evt, fmt.Errorf("nil message") + return evt, errors.New("nil message") } msg := *log.Message if cfg.PrependCloudwatchTimestamp != nil && *cfg.PrependCloudwatchTimestamp { diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go index 5d64755e2e9..d62c3f6e3dd 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go @@ -1,6 +1,8 @@ package cloudwatchacquisition import ( + "context" + "errors" "fmt" "net" "os" @@ -9,14 +11,16 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/cloudwatchlogs" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) /* @@ -31,6 +35,7 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { input := &cloudwatchlogs.DescribeLogGroupsInput{} result, err := cw.cwClient.DescribeLogGroups(input) require.NoError(t, err) + for _, group := range result.LogGroups { _, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{ LogGroupName: group.LogGroupName, @@ -42,14 +47,14 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { func checkForLocalStackAvailability() error { v := os.Getenv("AWS_ENDPOINT_FORCE") if v == "" { - return fmt.Errorf("missing aws endpoint for tests : AWS_ENDPOINT_FORCE") + return errors.New("missing aws endpoint for tests : AWS_ENDPOINT_FORCE") } v = strings.TrimPrefix(v, "http://") _, err := net.Dial("tcp", v) if err != nil { - return fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err) + return fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err) } return nil @@ -59,18 +64,22 @@ func TestMain(m *testing.M) { if runtime.GOOS == "windows" { os.Exit(0) } + if err := checkForLocalStackAvailability(); err != nil { log.Fatalf("local stack error : %s", err) } + def_PollNewStreamInterval = 1 * time.Second def_PollStreamInterval = 1 * time.Second def_StreamReadTimeout = 10 * time.Second def_MaxStreamAge = 5 * time.Second def_PollDeadStreamInterval = 5 * time.Second + os.Exit(m.Run()) } func TestWatchLogGroupForStreams(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -421,13 +430,12 @@ stream_name: test_stream`), } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) dbgLogger.Infof("starting test") cw := CloudwatchSource{} - err := cw.Configure(tc.config, dbgLogger) + err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedCfgErr) if tc.expectedCfgErr != "" { @@ -445,7 +453,7 @@ stream_name: test_stream`), dbgLogger.Infof("running StreamingAcquisition") actmb := tomb.Tomb{} actmb.Go(func() error { - err := cw.StreamingAcquisition(out, &actmb) + err := cw.StreamingAcquisition(ctx, out, &actmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, tc.expectedStartErr) return nil @@ -501,7 +509,6 @@ stream_name: test_stream`), if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) @@ -511,6 +518,7 @@ stream_name: test_stream`), } func TestConfiguration(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -554,12 +562,11 @@ stream_name: test_stream`), } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) cw := CloudwatchSource{} - err := cw.Configure(tc.config, dbgLogger) + err := cw.Configure(tc.config, dbgLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedCfgErr) if tc.expectedCfgErr != "" { return @@ -570,7 +577,7 @@ stream_name: test_stream`), switch cw.GetMode() { case "tail": - err = cw.StreamingAcquisition(out, &tmb) + err = cw.StreamingAcquisition(ctx, out, &tmb) case "cat": err = cw.OneShotAcquisition(out, &tmb) } @@ -619,7 +626,6 @@ func TestConfigureByDSN(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) @@ -741,7 +747,6 @@ func TestOneShotAcquisition(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { dbgLogger := log.New().WithField("test", tc.name) dbgLogger.Logger.SetLevel(log.DebugLevel) @@ -799,7 +804,6 @@ func TestOneShotAcquisition(t *testing.T) { if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 60f1100b35a..874b1556fd5 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -3,6 +3,7 @@ package dockeracquisition import ( "bufio" "context" + "errors" "fmt" "net/url" "regexp" @@ -41,11 +42,12 @@ type DockerConfiguration struct { ContainerID []string `yaml:"container_id"` ContainerNameRegexp []string `yaml:"container_name_regexp"` ContainerIDRegexp []string `yaml:"container_id_regexp"` - ForceInotify bool `yaml:"force_inotify"` + UseContainerLabels bool `yaml:"use_container_labels"` configuration.DataSourceCommonCfg `yaml:",inline"` } type DockerSource struct { + metricsLevel int Config DockerConfiguration runningContainerState map[string]*ContainerConfig compiledContainerName []*regexp.Regexp @@ -86,8 +88,12 @@ func (d *DockerSource) UnmarshalConfig(yamlConfig []byte) error { d.logger.Tracef("DockerAcquisition configuration: %+v", d.Config) } - if len(d.Config.ContainerName) == 0 && len(d.Config.ContainerID) == 0 && len(d.Config.ContainerIDRegexp) == 0 && len(d.Config.ContainerNameRegexp) == 0 { - return fmt.Errorf("no containers names or containers ID configuration provided") + if len(d.Config.ContainerName) == 0 && len(d.Config.ContainerID) == 0 && len(d.Config.ContainerIDRegexp) == 0 && len(d.Config.ContainerNameRegexp) == 0 && !d.Config.UseContainerLabels { + return errors.New("no containers names or containers ID configuration provided") + } + + if d.Config.UseContainerLabels && (len(d.Config.ContainerName) > 0 || len(d.Config.ContainerID) > 0 || len(d.Config.ContainerIDRegexp) > 0 || len(d.Config.ContainerNameRegexp) > 0) { + return errors.New("use_container_labels and container_name, container_id, container_id_regexp, container_name_regexp are mutually exclusive") } d.CheckIntervalDuration, err = time.ParseDuration(d.Config.CheckInterval) @@ -128,9 +134,9 @@ func (d *DockerSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (d *DockerSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (d *DockerSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { d.logger = logger - + d.metricsLevel = MetricsLevel err := d.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -220,7 +226,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg switch k { case "log_level": if len(v) != 1 { - return fmt.Errorf("only one 'log_level' parameters is required, not many") + return errors.New("only one 'log_level' parameters is required, not many") } lvl, err := log.ParseLevel(v[0]) if err != nil { @@ -229,17 +235,17 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.logger.Logger.SetLevel(lvl) case "until": if len(v) != 1 { - return fmt.Errorf("only one 'until' parameters is required, not many") + return errors.New("only one 'until' parameters is required, not many") } d.containerLogsOptions.Until = v[0] case "since": if len(v) != 1 { - return fmt.Errorf("only one 'since' parameters is required, not many") + return errors.New("only one 'since' parameters is required, not many") } d.containerLogsOptions.Since = v[0] case "follow_stdout": if len(v) != 1 { - return fmt.Errorf("only one 'follow_stdout' parameters is required, not many") + return errors.New("only one 'follow_stdout' parameters is required, not many") } followStdout, err := strconv.ParseBool(v[0]) if err != nil { @@ -249,7 +255,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.containerLogsOptions.ShowStdout = followStdout case "follow_stderr": if len(v) != 1 { - return fmt.Errorf("only one 'follow_stderr' parameters is required, not many") + return errors.New("only one 'follow_stderr' parameters is required, not many") } followStdErr, err := strconv.ParseBool(v[0]) if err != nil { @@ -259,7 +265,7 @@ func (d *DockerSource) ConfigureByDSN(dsn string, labels map[string]string, logg d.containerLogsOptions.ShowStderr = followStdErr case "docker_host": if len(v) != 1 { - return fmt.Errorf("only one 'docker_host' parameters is required, not many") + return errors.New("only one 'docker_host' parameters is required, not many") } if err := client.WithHost(v[0])(dockerClient); err != nil { return err @@ -292,7 +298,7 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er d.logger.Debugf("container with id %s is already being read from", container.ID) continue } - if containerConfig, ok := d.EvalContainer(container); ok { + if containerConfig := d.EvalContainer(container); containerConfig != nil { d.logger.Infof("reading logs from container %s", containerConfig.Name) d.logger.Debugf("logs options: %+v", *d.containerLogsOptions) dockerReader, err := d.Client.ContainerLogs(context.Background(), containerConfig.ID, *d.containerLogsOptions) @@ -325,7 +331,9 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er l.Src = containerConfig.Name l.Process = true l.Module = d.GetName() - linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() + if d.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() + } evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} out <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) @@ -372,41 +380,88 @@ func (d *DockerSource) getContainerTTY(containerId string) bool { return containerDetails.Config.Tty } -func (d *DockerSource) EvalContainer(container dockerTypes.Container) (*ContainerConfig, bool) { +func (d *DockerSource) getContainerLabels(containerId string) map[string]interface{} { + containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId) + if err != nil { + return map[string]interface{}{} + } + return parseLabels(containerDetails.Config.Labels) +} + +func (d *DockerSource) EvalContainer(container dockerTypes.Container) *ContainerConfig { for _, containerID := range d.Config.ContainerID { if containerID == container.ID { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } for _, containerName := range d.Config.ContainerName { for _, name := range container.Names { - if strings.HasPrefix(name, "/") && len(name) > 0 { + if strings.HasPrefix(name, "/") && name != "" { name = name[1:] } if name == containerName { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } - } for _, cont := range d.compiledContainerID { if matched := cont.MatchString(container.ID); matched { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } for _, cont := range d.compiledContainerName { for _, name := range container.Names { if matched := cont.MatchString(name); matched { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}, true + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} } } + } + if d.Config.UseContainerLabels { + parsedLabels := d.getContainerLabels(container.ID) + if len(parsedLabels) == 0 { + d.logger.Tracef("container has no 'crowdsec' labels set, ignoring container: %s", container.ID) + return nil + } + if _, ok := parsedLabels["enable"]; !ok { + d.logger.Errorf("container has 'crowdsec' labels set but no 'crowdsec.enable' key found") + return nil + } + enable, ok := parsedLabels["enable"].(string) + if !ok { + d.logger.Error("container has 'crowdsec.enable' label set but it's not a string") + return nil + } + if strings.ToLower(enable) != "true" { + d.logger.Debugf("container has 'crowdsec.enable' label not set to true ignoring container: %s", container.ID) + return nil + } + if _, ok = parsedLabels["labels"]; !ok { + d.logger.Error("container has 'crowdsec.enable' label set to true but no 'labels' keys found") + return nil + } + labelsTypeCast, ok := parsedLabels["labels"].(map[string]interface{}) + if !ok { + d.logger.Error("container has 'crowdsec.enable' label set to true but 'labels' is not a map") + return nil + } + d.logger.Debugf("container labels %+v", labelsTypeCast) + labels := make(map[string]string) + for k, v := range labelsTypeCast { + if v, ok := v.(string); ok { + log.Debugf("label %s is a string with value %s", k, v) + labels[k] = v + continue + } + d.logger.Errorf("label %s is not a string", k) + } + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(container.ID)} } - return &ContainerConfig{}, false + return nil } func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error { @@ -446,7 +501,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha if _, ok := d.runningContainerState[container.ID]; ok { continue } - if containerConfig, ok := d.EvalContainer(container); ok { + if containerConfig := d.EvalContainer(container); containerConfig != nil { monitChan <- containerConfig } } @@ -463,7 +518,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha } } -func (d *DockerSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.t = t monitChan := make(chan *ContainerConfig) deleteChan := make(chan *ContainerConfig) @@ -519,7 +574,7 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types } l := types.Line{} l.Raw = line - l.Labels = d.Config.Labels + l.Labels = container.Labels l.Time = time.Now().UTC() l.Src = container.Name l.Process = true @@ -534,11 +589,11 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types outChan <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) case <-readerTomb.Dying(): - //This case is to handle temporarily losing the connection to the docker socket - //The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) + // This case is to handle temporarily losing the connection to the docker socket + // The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) d.logger.Debugf("readerTomb dying for container %s, removing it from runningContainerState", container.Name) deleteChan <- container - //Also reset the Since to avoid re-reading logs + // Also reset the Since to avoid re-reading logs d.Config.Since = time.Now().UTC().Format(time.RFC3339) d.containerLogsOptions.Since = d.Config.Since return nil @@ -553,7 +608,7 @@ func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan * case newContainer := <-in: if _, ok := d.runningContainerState[newContainer.ID]; !ok { newContainer.t = &tomb.Tomb{} - newContainer.logger = d.logger.WithFields(log.Fields{"container_name": newContainer.Name}) + newContainer.logger = d.logger.WithField("container_name", newContainer.Name) newContainer.t.Go(func() error { return d.TailDocker(newContainer, outChan, deleteChan) }) diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index c4d23168a37..e394c9cbe79 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -11,16 +11,17 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" dockerTypes "github.com/docker/docker/api/types" dockerContainer "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" - "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) const testContainerName = "docker_test" @@ -54,13 +55,11 @@ container_name: }, } - subLogger := log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger := log.WithField("type", "docker") for _, test := range tests { f := DockerSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } @@ -107,9 +106,7 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "", }, } - subLogger := log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger := log.WithField("type", "docker") for _, test := range tests { f := DockerSource{} @@ -123,6 +120,7 @@ type mockDockerCli struct { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") @@ -162,19 +160,15 @@ container_name_regexp: for _, ts := range tests { var ( - logger *log.Logger + logger *log.Logger subLogger *log.Entry ) if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = logger.WithField("type", "docker") } else { - subLogger = log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = log.WithField("type", "docker") } readLogs = false @@ -182,7 +176,7 @@ container_name_regexp: out := make(chan types.Event) dockerSource := DockerSource{} - err := dockerSource.Configure([]byte(ts.config), subLogger) + err := dockerSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -192,7 +186,7 @@ container_name_regexp: readerTomb := &tomb.Tomb{} streamTomb := tomb.Tomb{} streamTomb.Go(func() error { - return dockerSource.StreamingAcquisition(out, &dockerTomb) + return dockerSource.StreamingAcquisition(ctx, out, &dockerTomb) }) readerTomb.Go(func() error { time.Sleep(1 * time.Second) @@ -227,7 +221,7 @@ container_name_regexp: } func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) { - if readLogs == true { + if readLogs { return []dockerTypes.Container{}, nil } @@ -242,7 +236,7 @@ func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes } func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, options dockerTypes.ContainerLogsOptions) (io.ReadCloser, error) { - if readLogs == true { + if readLogs { return io.NopCloser(strings.NewReader("")), nil } @@ -252,7 +246,7 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o for _, line := range data { startLineByte := make([]byte, 8) - binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream + binary.LittleEndian.PutUint32(startLineByte, 1) // stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } @@ -304,19 +298,15 @@ func TestOneShot(t *testing.T) { for _, ts := range tests { var ( subLogger *log.Entry - logger *log.Logger + logger *log.Logger ) if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = logger.WithField("type", "docker") } else { log.SetLevel(ts.logLevel) - subLogger = log.WithFields(log.Fields{ - "type": "docker", - }) + subLogger = log.WithField("type", "docker") } readLogs = false @@ -340,3 +330,54 @@ func TestOneShot(t *testing.T) { } } } + +func TestParseLabels(t *testing.T) { + tests := []struct { + name string + labels map[string]string + expected map[string]interface{} + }{ + { + name: "bad label", + labels: map[string]string{"crowdsecfoo": "bar"}, + expected: map[string]interface{}{}, + }, + { + name: "simple label", + labels: map[string]string{"crowdsec.bar": "baz"}, + expected: map[string]interface{}{"bar": "baz"}, + }, + { + name: "multiple simple labels", + labels: map[string]string{"crowdsec.bar": "baz", "crowdsec.foo": "bar"}, + expected: map[string]interface{}{"bar": "baz", "foo": "bar"}, + }, + { + name: "multiple simple labels 2", + labels: map[string]string{"crowdsec.bar": "baz", "bla": "foo"}, + expected: map[string]interface{}{"bar": "baz"}, + }, + { + name: "end with dot", + labels: map[string]string{"crowdsec.bar.": "baz"}, + expected: map[string]interface{}{}, + }, + { + name: "consecutive dots", + labels: map[string]string{"crowdsec......bar": "baz"}, + expected: map[string]interface{}{}, + }, + { + name: "crowdsec labels", + labels: map[string]string{"crowdsec.labels.type": "nginx"}, + expected: map[string]interface{}{"labels": map[string]interface{}{"type": "nginx"}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + labels := parseLabels(test.labels) + assert.Equal(t, test.expected, labels) + }) + } +} diff --git a/pkg/acquisition/modules/docker/utils.go b/pkg/acquisition/modules/docker/utils.go new file mode 100644 index 00000000000..6a0d494097f --- /dev/null +++ b/pkg/acquisition/modules/docker/utils.go @@ -0,0 +1,38 @@ +package dockeracquisition + +import ( + "strings" +) + +func parseLabels(labels map[string]string) map[string]interface{} { + result := make(map[string]interface{}) + for key, value := range labels { + parseKeyToMap(result, key, value) + } + return result +} + +func parseKeyToMap(m map[string]interface{}, key string, value string) { + if !strings.HasPrefix(key, "crowdsec") { + return + } + parts := strings.Split(key, ".") + + if len(parts) < 2 || parts[0] != "crowdsec" { + return + } + + for i := range parts { + if parts[i] == "" { + return + } + } + + for i := 1; i < len(parts)-1; i++ { + if _, ok := m[parts[i]]; !ok { + m[parts[i]] = make(map[string]interface{}) + } + m = m[parts[i]].(map[string]interface{}) + } + m[parts[len(parts)-1]] = value +} diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index 4ea9466d457..2d2df3ff4d4 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -3,6 +3,8 @@ package fileacquisition import ( "bufio" "compress/gzip" + "context" + "errors" "fmt" "io" "net/url" @@ -11,11 +13,11 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/fsnotify/fsnotify" "github.com/nxadm/tail" - "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" @@ -38,13 +40,14 @@ type FileConfiguration struct { Filenames []string ExcludeRegexps []string `yaml:"exclude_regexps"` Filename string - ForceInotify bool `yaml:"force_inotify"` - MaxBufferSize int `yaml:"max_buffer_size"` - PollWithoutInotify bool `yaml:"poll_without_inotify"` + ForceInotify bool `yaml:"force_inotify"` + MaxBufferSize int `yaml:"max_buffer_size"` + PollWithoutInotify *bool `yaml:"poll_without_inotify"` configuration.DataSourceCommonCfg `yaml:",inline"` } type FileSource struct { + metricsLevel int config FileConfiguration watcher *fsnotify.Watcher watchedDirectories map[string]bool @@ -52,6 +55,7 @@ type FileSource struct { logger *log.Entry files []string exclude_regexps []*regexp.Regexp + tailMapMutex *sync.RWMutex } func (f *FileSource) GetUuid() string { @@ -60,6 +64,7 @@ func (f *FileSource) GetUuid() string { func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { f.config = FileConfiguration{} + err := yaml.UnmarshalStrict(yamlConfig, &f.config) if err != nil { return fmt.Errorf("cannot parse FileAcquisition configuration: %w", err) @@ -69,12 +74,12 @@ func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { f.logger.Tracef("FileAcquisition configuration: %+v", f.config) } - if len(f.config.Filename) != 0 { + if f.config.Filename != "" { f.config.Filenames = append(f.config.Filenames, f.config.Filename) } if len(f.config.Filenames) == 0 { - return fmt.Errorf("no filename or filenames configuration provided") + return errors.New("no filename or filenames configuration provided") } if f.config.Mode == "" { @@ -90,14 +95,16 @@ func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { if err != nil { return fmt.Errorf("could not compile regexp %s: %w", exclude, err) } + f.exclude_regexps = append(f.exclude_regexps, re) } return nil } -func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { f.logger = logger + f.metricsLevel = MetricsLevel err := f.UnmarshalConfig(yamlConfig) if err != nil { @@ -105,6 +112,7 @@ func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { } f.watchedDirectories = make(map[string]bool) + f.tailMapMutex = &sync.RWMutex{} f.tails = make(map[string]bool) f.watcher, err = fsnotify.NewWatcher() @@ -118,56 +126,68 @@ func (f *FileSource) Configure(yamlConfig []byte, logger *log.Entry) error { if f.config.ForceInotify { directory := filepath.Dir(pattern) f.logger.Infof("Force add watch on %s", directory) + if !f.watchedDirectories[directory] { err = f.watcher.Add(directory) if err != nil { f.logger.Errorf("Could not create watch on directory %s : %s", directory, err) continue } + f.watchedDirectories[directory] = true } } + files, err := filepath.Glob(pattern) if err != nil { return fmt.Errorf("glob failure: %w", err) } + if len(files) == 0 { f.logger.Warnf("No matching files for pattern %s", pattern) continue } - for _, file := range files { - //check if file is excluded + for _, file := range files { + // check if file is excluded excluded := false + for _, pattern := range f.exclude_regexps { if pattern.MatchString(file) { excluded = true + f.logger.Infof("Skipping file %s as it matches exclude pattern %s", file, pattern) + break } } + if excluded { continue } - if files[0] != pattern && f.config.Mode == configuration.TAIL_MODE { //we have a glob pattern + + if files[0] != pattern && f.config.Mode == configuration.TAIL_MODE { // we have a glob pattern directory := filepath.Dir(file) f.logger.Debugf("Will add watch to directory: %s", directory) - if !f.watchedDirectories[directory] { + if !f.watchedDirectories[directory] { err = f.watcher.Add(directory) if err != nil { f.logger.Errorf("Could not create watch on directory %s : %s", directory, err) continue } + f.watchedDirectories[directory] = true } else { f.logger.Debugf("Watch for directory %s already exists", directory) } } + f.logger.Infof("Adding file %s to datasources", file) f.files = append(f.files, file) } } + return nil } @@ -183,34 +203,39 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger args := strings.Split(dsn, "?") - if len(args[0]) == 0 { - return fmt.Errorf("empty file:// DSN") + if args[0] == "" { + return errors.New("empty file:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse file args: %w", err) } + for key, value := range params { switch key { case "log_level": if len(value) != 1 { return errors.New("expected zero or one value for 'log_level'") } + lvl, err := log.ParseLevel(value[0]) if err != nil { return fmt.Errorf("unknown level %s: %w", value[0], err) } + f.logger.Logger.SetLevel(lvl) case "max_buffer_size": if len(value) != 1 { return errors.New("expected zero or one value for 'max_buffer_size'") } + maxBufferSize, err := strconv.Atoi(value[0]) if err != nil { return fmt.Errorf("could not parse max_buffer_size %s: %w", value[0], err) } + f.config.MaxBufferSize = maxBufferSize default: return fmt.Errorf("unknown parameter %s", key) @@ -223,6 +248,7 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger f.config.UniqueId = uuid f.logger.Debugf("Will try pattern %s", args[0]) + files, err := filepath.Glob(args[0]) if err != nil { return fmt.Errorf("glob failure: %w", err) @@ -240,6 +266,7 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger f.logger.Infof("Adding file %s to filelist", file) f.files = append(f.files, file) } + return nil } @@ -255,22 +282,26 @@ func (f *FileSource) SupportedModes() []string { // OneShotAcquisition reads a set of file and returns when done func (f *FileSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("In oneshot") + for _, file := range f.files { fi, err := os.Stat(file) if err != nil { return fmt.Errorf("could not stat file %s : %w", file, err) } + if fi.IsDir() { f.logger.Warnf("%s is a directory, ignoring it.", file) continue } + f.logger.Infof("reading %s at once", file) + err = f.readFile(file, out, t) if err != nil { return err } - } + return nil } @@ -290,32 +321,38 @@ func (f *FileSource) CanRun() error { return nil } -func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("Starting live acquisition") t.Go(func() error { return f.monitorNewFiles(out, t) }) + for _, file := range f.files { - //before opening the file, check if we need to specifically avoid it. (XXX) + // before opening the file, check if we need to specifically avoid it. (XXX) skip := false + for _, pattern := range f.exclude_regexps { if pattern.MatchString(file) { f.logger.Infof("file %s matches exclusion pattern %s, skipping", file, pattern.String()) + skip = true + break } } + if skip { continue } - //cf. https://github.com/crowdsecurity/crowdsec/issues/1168 - //do not rely on stat, reclose file immediately as it's opened by Tail + // cf. https://github.com/crowdsecurity/crowdsec/issues/1168 + // do not rely on stat, reclose file immediately as it's opened by Tail fd, err := os.Open(file) if err != nil { f.logger.Errorf("unable to read %s : %s", file, err) continue } + if err := fd.Close(); err != nil { f.logger.Errorf("unable to close %s : %s", file, err) continue @@ -325,22 +362,54 @@ func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er if err != nil { return fmt.Errorf("could not stat file %s : %w", file, err) } + if fi.IsDir() { f.logger.Warnf("%s is a directory, ignoring it.", file) continue } - tail, err := tail.TailFile(file, tail.Config{ReOpen: true, Follow: true, Poll: f.config.PollWithoutInotify, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekEnd}, Logger: log.NewEntry(log.StandardLogger())}) + pollFile := false + if f.config.PollWithoutInotify != nil { + pollFile = *f.config.PollWithoutInotify + } else { + networkFS, fsType, err := types.IsNetworkFS(file) + if err != nil { + f.logger.Warningf("Could not get fs type for %s : %s", file, err) + } + + f.logger.Debugf("fs for %s is network: %t (%s)", file, networkFS, fsType) + + if networkFS { + f.logger.Warnf("Disabling inotify polling on %s as it is on a network share. You can manually set poll_without_inotify to true to make this message disappear, or to false to enforce inotify poll", file) + pollFile = true + } + } + + filink, err := os.Lstat(file) + if err != nil { + f.logger.Errorf("Could not lstat() new file %s, ignoring it : %s", file, err) + continue + } + + if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { + f.logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", file) + } + + tail, err := tail.TailFile(file, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekEnd}, Logger: log.NewEntry(log.StandardLogger())}) if err != nil { f.logger.Errorf("Could not start tailing file %s : %s", file, err) continue } + + f.tailMapMutex.Lock() f.tails[file] = true + f.tailMapMutex.Unlock() t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/file/live/fsnotify") return f.tailFile(out, t, tail) }) } + return nil } @@ -350,6 +419,7 @@ func (f *FileSource) Dump() interface{} { func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { logger := f.logger.WithField("goroutine", "inotify") + for { select { case event, ok := <-f.watcher.Events: @@ -357,84 +427,134 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { return nil } - if event.Op&fsnotify.Create == fsnotify.Create { - fi, err := os.Stat(event.Name) + if event.Op&fsnotify.Create != fsnotify.Create { + continue + } + + fi, err := os.Stat(event.Name) + if err != nil { + logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) + continue + } + + if fi.IsDir() { + continue + } + + logger.Debugf("Detected new file %s", event.Name) + + matched := false + + for _, pattern := range f.config.Filenames { + logger.Debugf("Matching %s with %s", pattern, event.Name) + + matched, err = filepath.Match(pattern, event.Name) if err != nil { - logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) - continue - } - if fi.IsDir() { - continue - } - logger.Debugf("Detected new file %s", event.Name) - matched := false - for _, pattern := range f.config.Filenames { - logger.Debugf("Matching %s with %s", pattern, event.Name) - matched, err = filepath.Match(pattern, event.Name) - if err != nil { - logger.Errorf("Could not match pattern : %s", err) - continue - } - if matched { - logger.Debugf("Matched %s with %s", pattern, event.Name) - break - } - } - if !matched { + logger.Errorf("Could not match pattern : %s", err) continue } - //before opening the file, check if we need to specifically avoid it. (XXX) - skip := false - for _, pattern := range f.exclude_regexps { - if pattern.MatchString(event.Name) { - f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) - skip = true - break - } - } - if skip { - continue + if matched { + logger.Debugf("Matched %s with %s", pattern, event.Name) + break } + } + + if !matched { + continue + } + + // before opening the file, check if we need to specifically avoid it. (XXX) + skip := false + + for _, pattern := range f.exclude_regexps { + if pattern.MatchString(event.Name) { + f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) + + skip = true - if f.tails[event.Name] { - //we already have a tail on it, do not start a new one - logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) break } - //cf. https://github.com/crowdsecurity/crowdsec/issues/1168 - //do not rely on stat, reclose file immediately as it's opened by Tail - fd, err := os.Open(event.Name) + } + + if skip { + continue + } + + f.tailMapMutex.RLock() + if f.tails[event.Name] { + f.tailMapMutex.RUnlock() + // we already have a tail on it, do not start a new one + logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) + + break + } + f.tailMapMutex.RUnlock() + // cf. https://github.com/crowdsecurity/crowdsec/issues/1168 + // do not rely on stat, reclose file immediately as it's opened by Tail + fd, err := os.Open(event.Name) + if err != nil { + f.logger.Errorf("unable to read %s : %s", event.Name, err) + continue + } + + if err = fd.Close(); err != nil { + f.logger.Errorf("unable to close %s : %s", event.Name, err) + continue + } + + pollFile := false + if f.config.PollWithoutInotify != nil { + pollFile = *f.config.PollWithoutInotify + } else { + networkFS, fsType, err := types.IsNetworkFS(event.Name) if err != nil { - f.logger.Errorf("unable to read %s : %s", event.Name, err) - continue - } - if err := fd.Close(); err != nil { - f.logger.Errorf("unable to close %s : %s", event.Name, err) - continue + f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err) } - //Slightly different parameters for Location, as we want to read the first lines of the newly created file - tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: f.config.PollWithoutInotify, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) - if err != nil { - logger.Errorf("Could not start tailing file %s : %s", event.Name, err) - break + + f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType) + + if networkFS { + pollFile = true } - f.tails[event.Name] = true - t.Go(func() error { - defer trace.CatchPanic("crowdsec/acquis/tailfile") - return f.tailFile(out, t, tail) - }) } + + filink, err := os.Lstat(event.Name) + if err != nil { + logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err) + continue + } + + if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { + logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", event.Name) + } + + // Slightly different parameters for Location, as we want to read the first lines of the newly created file + tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) + if err != nil { + logger.Errorf("Could not start tailing file %s : %s", event.Name, err) + break + } + + f.tailMapMutex.Lock() + f.tails[event.Name] = true + f.tailMapMutex.Unlock() + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/tailfile") + return f.tailFile(out, t, tail) + }) case err, ok := <-f.watcher.Errors: if !ok { return nil } + logger.Errorf("Error while monitoring folder: %s", err) case <-t.Dying(): err := f.watcher.Close() if err != nil { return fmt.Errorf("could not remove all inotify watches: %w", err) } + return nil } } @@ -443,46 +563,62 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tail) error { logger := f.logger.WithField("tail", tail.Filename) logger.Debugf("-> Starting tail of %s", tail.Filename) + for { select { case <-t.Dying(): logger.Infof("File datasource %s stopping", tail.Filename) + if err := tail.Stop(); err != nil { f.logger.Errorf("error in stop : %s", err) return err } + return nil - case <-tail.Dying(): //our tailer is dying - err := tail.Err() + case <-tail.Dying(): // our tailer is dying errMsg := fmt.Sprintf("file reader of %s died", tail.Filename) + + err := tail.Err() if err != nil { errMsg = fmt.Sprintf(errMsg+" : %s", err) } - logger.Warningf(errMsg) - t.Kill(fmt.Errorf(errMsg)) - return fmt.Errorf(errMsg) + + logger.Warning(errMsg) + + return nil case line := <-tail.Lines: if line == nil { logger.Warningf("tail for %s is empty", tail.Filename) continue } + if line.Err != nil { logger.Warningf("fetch error : %v", line.Err) return line.Err } - if line.Text == "" { //skip empty lines + + if line.Text == "" { // skip empty lines continue } - linesRead.With(prometheus.Labels{"source": tail.Filename}).Inc() + + if f.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": tail.Filename}).Inc() + } + + src := tail.Filename + if f.metricsLevel == configuration.METRICS_AGGREGATE { + src = filepath.Base(tail.Filename) + } + l := types.Line{ Raw: trimLine(line.Text), Labels: f.config.Labels, Time: line.Time, - Src: tail.Filename, + Src: src, Process: true, Module: f.GetName(), } - //we're tailing, it must be real time logs + // we're tailing, it must be real time logs logger.Debugf("pushing %+v", l) expectMode := types.LIVE @@ -496,12 +632,14 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tomb) error { var scanner *bufio.Scanner + logger := f.logger.WithField("oneshot", filename) - fd, err := os.Open(filename) + fd, err := os.Open(filename) if err != nil { return fmt.Errorf("failed opening %s: %w", filename, err) } + defer fd.Close() if strings.HasSuffix(filename, ".gz") { @@ -510,17 +648,20 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom logger.Errorf("Failed to read gz file: %s", err) return fmt.Errorf("failed to read gz %s: %w", filename, err) } + defer gz.Close() scanner = bufio.NewScanner(gz) - } else { scanner = bufio.NewScanner(fd) } + scanner.Split(bufio.ScanLines) + if f.config.MaxBufferSize > 0 { buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, f.config.MaxBufferSize) } + for scanner.Scan() { select { case <-t.Dying(): @@ -530,6 +671,7 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom if scanner.Text() == "" { continue } + l := types.Line{ Raw: scanner.Text(), Time: time.Now().UTC(), @@ -541,15 +683,19 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom logger.Debugf("line %s", l.Raw) linesRead.With(prometheus.Labels{"source": filename}).Inc() - //we're reading logs at once, it must be time-machine buckets + // we're reading logs at once, it must be time-machine buckets out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} } } + if err := scanner.Err(); err != nil { logger.Errorf("Error while reading file: %s", err) t.Kill(err) + return err } + t.Kill(nil) + return nil } diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index 410beb4bc85..3db0042ba2f 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,6 +1,7 @@ package fileacquisition_test import ( + "context" "fmt" "os" "runtime" @@ -15,6 +16,7 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -48,15 +50,12 @@ exclude_regexps: ["as[a-$d"]`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "file", - }) + subLogger := log.WithField("type", "file") for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := fileacquisition.FileSource{} - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } @@ -90,12 +89,9 @@ func TestConfigureDSN(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "file", - }) + subLogger := log.WithField("type", "file") for _, tc := range tests { - tc := tc t.Run(tc.dsn, func(t *testing.T) { f := fileacquisition.FileSource{} err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -205,14 +201,11 @@ filename: test_files/test_delete.log`, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := test.NewNullLogger() logger.SetLevel(tc.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") tomb := tomb.Tomb{} out := make(chan types.Event, 100) @@ -222,7 +215,7 @@ filename: test_files/test_delete.log`, tc.setup() } - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.expectedConfigErr) if tc.expectedConfigErr != "" { return @@ -251,6 +244,7 @@ filename: test_files/test_delete.log`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() permDeniedFile := "/etc/shadow" permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" testPattern := "test_files/*.log" @@ -366,14 +360,11 @@ force_inotify: true`, testPattern), } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := test.NewNullLogger() logger.SetLevel(tc.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") tomb := tomb.Tomb{} out := make(chan types.Event) @@ -384,7 +375,7 @@ force_inotify: true`, testPattern), tc.setup() } - err := f.Configure([]byte(tc.config), subLogger) + err := f.Configure([]byte(tc.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) if tc.afterConfigure != nil { @@ -405,18 +396,18 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(out, &tomb) + err = f.StreamingAcquisition(ctx, out, &tomb) cstest.RequireErrorContains(t, err, tc.expectedErr) if tc.expectedLines != 0 { fd, err := os.Create("test_files/stream.log") require.NoError(t, err, "could not create test file") - for i := 0; i < 5; i++ { + for i := range 5 { _, err = fmt.Fprintf(fd, "%d\n", i) if err != nil { - t.Fatalf("could not write test file : %s", err) os.Remove("test_files/stream.log") + t.Fatalf("could not write test file : %s", err) } } @@ -450,12 +441,10 @@ func TestExclusion(t *testing.T) { exclude_regexps: ["\\.gz$"]` logger, hook := test.NewNullLogger() // logger.SetLevel(ts.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) + subLogger := logger.WithField("type", "file") f := fileacquisition.FileSource{} - if err := f.Configure([]byte(config), subLogger); err != nil { + if err := f.Configure([]byte(config), subLogger, configuration.METRICS_NONE); err != nil { subLogger.Fatalf("unexpected error: %s", err) } diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index 55091a7b5eb..b9cda54a472 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go +++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -3,6 +3,7 @@ package journalctlacquisition import ( "bufio" "context" + "errors" "fmt" "net/url" "os/exec" @@ -26,10 +27,11 @@ type JournalCtlConfiguration struct { } type JournalCtlSource struct { - config JournalCtlConfiguration - logger *log.Entry - src string - args []string + metricsLevel int + config JournalCtlConfiguration + logger *log.Entry + src string + args []string } const journalctlCmd string = "journalctl" @@ -97,7 +99,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err if stdoutscanner == nil { cancel() cmd.Wait() - return fmt.Errorf("failed to create stdout scanner") + return errors.New("failed to create stdout scanner") } stderrScanner := bufio.NewScanner(stderr) @@ -105,13 +107,13 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err if stderrScanner == nil { cancel() cmd.Wait() - return fmt.Errorf("failed to create stderr scanner") + return errors.New("failed to create stderr scanner") } t.Go(func() error { return readLine(stdoutscanner, stdoutChan, errChan) }) t.Go(func() error { - //looks like journalctl closes stderr quite early, so ignore its status (but not its output) + // looks like journalctl closes stderr quite early, so ignore its status (but not its output) return readLine(stderrScanner, stderrChan, nil) }) @@ -120,7 +122,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err case <-t.Dying(): logger.Infof("journalctl datasource %s stopping", j.src) cancel() - cmd.Wait() //avoid zombie process + cmd.Wait() // avoid zombie process return nil case stdoutLine := <-stdoutChan: l := types.Line{} @@ -131,7 +133,9 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err l.Src = j.src l.Process = true l.Module = j.GetName() - linesRead.With(prometheus.Labels{"source": j.src}).Inc() + if j.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": j.src}).Inc() + } var evt types.Event if !j.config.UseTimeMachine { evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} @@ -186,7 +190,7 @@ func (j *JournalCtlSource) UnmarshalConfig(yamlConfig []byte) error { } if len(j.config.Filters) == 0 { - return fmt.Errorf("journalctl_filter is required") + return errors.New("journalctl_filter is required") } j.args = append(args, j.config.Filters...) j.src = fmt.Sprintf("journalctl-%s", strings.Join(j.config.Filters, ".")) @@ -194,8 +198,9 @@ func (j *JournalCtlSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (j *JournalCtlSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (j *JournalCtlSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { j.logger = logger + j.metricsLevel = MetricsLevel err := j.UnmarshalConfig(yamlConfig) if err != nil { @@ -212,14 +217,14 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Labels = labels j.config.UniqueId = uuid - //format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 + // format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 if !strings.HasPrefix(dsn, "journalctl://") { return fmt.Errorf("invalid DSN %s for journalctl source, must start with journalctl://", dsn) } qs := strings.TrimPrefix(dsn, "journalctl://") - if len(qs) == 0 { - return fmt.Errorf("empty journalctl:// DSN") + if qs == "" { + return errors.New("empty journalctl:// DSN") } params, err := url.ParseQuery(qs) @@ -232,7 +237,7 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Filters = append(j.config.Filters, value...) case "log_level": if len(value) != 1 { - return fmt.Errorf("expected zero or one value for 'log_level'") + return errors.New("expected zero or one value for 'log_level'") } lvl, err := log.ParseLevel(value[0]) if err != nil { @@ -262,21 +267,22 @@ func (j *JournalCtlSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb err := j.runJournalCtl(out, t) j.logger.Debug("Oneshot journalctl acquisition is done") return err - } -func (j *JournalCtlSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/journalctl/streaming") return j.runJournalCtl(out, t) }) return nil } + func (j *JournalCtlSource) CanRun() error { - //TODO: add a more precise check on version or something ? + // TODO: add a more precise check on version or something ? _, err := exec.LookPath(journalctlCmd) return err } + func (j *JournalCtlSource) Dump() interface{} { return j } diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index a91fba31b34..c416bb5d23e 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -1,6 +1,7 @@ package journalctlacquisition import ( + "context" "os" "os/exec" "path/filepath" @@ -8,13 +9,15 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { @@ -46,13 +49,11 @@ journalctl_filter: }, } - subLogger := log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger := log.WithField("type", "journalctl") for _, test := range tests { f := JournalCtlSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } @@ -96,9 +97,7 @@ func TestConfigureDSN(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger := log.WithField("type", "journalctl") for _, test := range tests { f := JournalCtlSource{} @@ -144,28 +143,24 @@ journalctl_filter: } for _, ts := range tests { var ( - logger *log.Logger + logger *log.Logger subLogger *log.Entry - hook *test.Hook + hook *test.Hook ) if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = logger.WithField("type", "journalctl") } else { - subLogger = log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = log.WithField("type", "journalctl") } tomb := tomb.Tomb{} out := make(chan types.Event, 100) j := JournalCtlSource{} - err := j.Configure([]byte(ts.config), subLogger) + err := j.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -193,6 +188,7 @@ journalctl_filter: } func TestStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -218,28 +214,24 @@ journalctl_filter: } for _, ts := range tests { var ( - logger *log.Logger + logger *log.Logger subLogger *log.Entry - hook *test.Hook + hook *test.Hook ) if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) - subLogger = logger.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = logger.WithField("type", "journalctl") } else { - subLogger = log.WithFields(log.Fields{ - "type": "journalctl", - }) + subLogger = log.WithField("type", "journalctl") } tomb := tomb.Tomb{} out := make(chan types.Event) j := JournalCtlSource{} - err := j.Configure([]byte(ts.config), subLogger) + err := j.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -260,7 +252,7 @@ journalctl_filter: }() } - err = j.StreamingAcquisition(out, &tomb) + err = j.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if err != nil { diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index 5b6e8fc0d41..9fd5fc2a035 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -23,9 +23,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - dataSourceName = "kafka" -) +var dataSourceName = "kafka" var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -52,9 +50,10 @@ type TLSConfig struct { } type KafkaSource struct { - Config KafkaConfiguration - logger *log.Entry - Reader *kafka.Reader + metricsLevel int + Config KafkaConfiguration + logger *log.Entry + Reader *kafka.Reader } func (k *KafkaSource) GetUuid() string { @@ -81,13 +80,14 @@ func (k *KafkaSource) UnmarshalConfig(yamlConfig []byte) error { k.Config.Mode = configuration.TAIL_MODE } - k.logger.Debugf("successfully unmarshaled kafka configuration : %+v", k.Config) + k.logger.Debugf("successfully parsed kafka configuration : %+v", k.Config) return err } -func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (k *KafkaSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { k.logger = logger + k.metricsLevel = MetricsLevel k.logger.Debugf("start configuring %s source", dataSourceName) @@ -170,7 +170,9 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error { Module: k.GetName(), } k.logger.Tracef("line with message read from topic '%s': %+v", k.Config.Topic, l) - linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() + if k.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() + } var evt types.Event if !k.Config.UseTimeMachine { @@ -200,7 +202,7 @@ func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { } } -func (k *KafkaSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { k.logger.Infof("start reader on brokers '%+v' with topic '%s'", k.Config.Brokers, k.Config.Topic) t.Go(func() error { @@ -274,7 +276,7 @@ func (kc *KafkaConfiguration) NewReader(dialer *kafka.Dialer, logger *log.Entry) ErrorLogger: kafka.LoggerFunc(logger.Errorf), } if kc.GroupID != "" && kc.Partition != 0 { - return &kafka.Reader{}, fmt.Errorf("cannot specify both group_id and partition") + return &kafka.Reader{}, errors.New("cannot specify both group_id and partition") } if kc.GroupID != "" { rConf.GroupID = kc.GroupID diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 92ccd4c7a3f..d796166a6ca 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -15,6 +15,7 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -70,20 +71,18 @@ group_id: crowdsec`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") + for _, test := range tests { k := KafkaSource{} - err := k.Configure([]byte(test.config), subLogger) + err := k.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } -func writeToKafka(w *kafka.Writer, logs []string) { - +func writeToKafka(ctx context.Context, w *kafka.Writer, logs []string) { for idx, log := range logs { - err := w.WriteMessages(context.Background(), kafka.Message{ + err := w.WriteMessages(ctx, kafka.Message{ Key: []byte(strconv.Itoa(idx)), // create an arbitrary message payload for the value Value: []byte(log), @@ -105,7 +104,9 @@ func createTopic(topic string, broker string) { if err != nil { panic(err) } + var controllerConn *kafka.Conn + controllerConn, err = kafka.Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port))) if err != nil { panic(err) @@ -127,9 +128,11 @@ func createTopic(topic string, broker string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { name string logs []string @@ -147,9 +150,7 @@ func TestStreamingAcquisition(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") createTopic("crowdsecplaintext", "localhost:9092") @@ -158,28 +159,30 @@ func TestStreamingAcquisition(t *testing.T) { Topic: "crowdsecplaintext", }) if w == nil { - log.Fatalf("Unable to setup a kafka producer") + t.Fatal("Unable to setup a kafka producer") } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { k := KafkaSource{} + err := k.Configure([]byte(` source: kafka brokers: - localhost:9092 -topic: crowdsecplaintext`), subLogger) +topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure kafka source : %s", err) } + tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w, ts.logs) + + go writeToKafka(ctx, w, ts.logs) READLOOP: for { select { @@ -194,13 +197,14 @@ topic: crowdsecplaintext`), subLogger) tomb.Wait() }) } - } func TestStreamingAcquisitionWithSSL(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { name string logs []string @@ -217,9 +221,7 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { }, } - subLogger := log.WithFields(log.Fields{ - "type": "kafka", - }) + subLogger := log.WithField("type", "kafka") createTopic("crowdsecssl", "localhost:9092") @@ -228,13 +230,13 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { Topic: "crowdsecssl", }) if w2 == nil { - log.Fatalf("Unable to setup a kafka producer") + t.Fatal("Unable to setup a kafka producer") } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { k := KafkaSource{} + err := k.Configure([]byte(` source: kafka brokers: @@ -245,17 +247,19 @@ tls: client_cert: ./testdata/kafkaClient.certificate.pem client_key: ./testdata/kafkaClient.key ca_cert: ./testdata/snakeoil-ca-1.crt - `), subLogger) + `), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure kafka source : %s", err) } + tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w2, ts.logs) + + go writeToKafka(ctx, w2, ts.logs) READLOOP: for { select { @@ -270,5 +274,4 @@ tls: tomb.Wait() }) } - } diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index e2cc7996349..ca3a847dbfb 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -3,7 +3,9 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" + "errors" "fmt" "io" "strings" @@ -28,7 +30,7 @@ type KinesisConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` StreamName string `yaml:"stream_name"` StreamARN string `yaml:"stream_arn"` - UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` //Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords + UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` // Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords AwsProfile *string `yaml:"aws_profile"` AwsRegion string `yaml:"aws_region"` AwsEndpoint string `yaml:"aws_endpoint"` @@ -38,6 +40,7 @@ type KinesisConfiguration struct { } type KinesisSource struct { + metricsLevel int Config KinesisConfiguration logger *log.Entry kClient *kinesis.Kinesis @@ -94,7 +97,7 @@ func (k *KinesisSource) newClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } config := aws.NewConfig() if k.Config.AwsRegion != "" { @@ -105,15 +108,15 @@ func (k *KinesisSource) newClient() error { } k.kClient = kinesis.New(sess, config) if k.kClient == nil { - return fmt.Errorf("failed to create kinesis client") + return errors.New("failed to create kinesis client") } return nil } func (k *KinesisSource) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} - } + func (k *KinesisSource) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} } @@ -123,7 +126,7 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { err := yaml.UnmarshalStrict(yamlConfig, &k.Config) if err != nil { - return fmt.Errorf("Cannot parse kinesis datasource configuration: %w", err) + return fmt.Errorf("cannot parse kinesis datasource configuration: %w", err) } if k.Config.Mode == "" { @@ -131,16 +134,16 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { } if k.Config.StreamName == "" && !k.Config.UseEnhancedFanOut { - return fmt.Errorf("stream_name is mandatory when use_enhanced_fanout is false") + return errors.New("stream_name is mandatory when use_enhanced_fanout is false") } if k.Config.StreamARN == "" && k.Config.UseEnhancedFanOut { - return fmt.Errorf("stream_arn is mandatory when use_enhanced_fanout is true") + return errors.New("stream_arn is mandatory when use_enhanced_fanout is true") } if k.Config.ConsumerName == "" && k.Config.UseEnhancedFanOut { - return fmt.Errorf("consumer_name is mandatory when use_enhanced_fanout is true") + return errors.New("consumer_name is mandatory when use_enhanced_fanout is true") } if k.Config.StreamARN != "" && k.Config.StreamName != "" { - return fmt.Errorf("stream_arn and stream_name are mutually exclusive") + return errors.New("stream_arn and stream_name are mutually exclusive") } if k.Config.MaxRetries <= 0 { k.Config.MaxRetries = 10 @@ -149,8 +152,9 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { k.logger = logger + k.metricsLevel = MetricsLevel err := k.UnmarshalConfig(yamlConfig) if err != nil { @@ -167,7 +171,7 @@ func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry) error { } func (k *KinesisSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { - return fmt.Errorf("kinesis datasource does not support command-line acquisition") + return errors.New("kinesis datasource does not support command-line acquisition") } func (k *KinesisSource) GetMode() string { @@ -179,13 +183,12 @@ func (k *KinesisSource) GetName() string { } func (k *KinesisSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("kinesis datasource does not support one-shot acquisition") + return errors.New("kinesis datasource does not support one-shot acquisition") } func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { b := bytes.NewBuffer(record) r, err := gzip.NewReader(b) - if err != nil { k.logger.Error(err) return nil, err @@ -206,7 +209,7 @@ func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubsc func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, streamARN string) error { maxTries := k.Config.MaxRetries - for i := 0; i < maxTries; i++ { + for i := range maxTries { _, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ ConsumerName: aws.String(consumerName), StreamARN: aws.String(streamARN), @@ -247,7 +250,7 @@ func (k *KinesisSource) DeregisterConsumer() error { func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error { maxTries := k.Config.MaxRetries - for i := 0; i < maxTries; i++ { + for i := range maxTries { describeOutput, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ ConsumerARN: aws.String(consumerARN), }) @@ -283,17 +286,21 @@ func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutpu func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan types.Event, logger *log.Entry, shardId string) { for _, record := range records { if k.Config.StreamARN != "" { - linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamARN, "shard": shardId}).Inc() - linesRead.With(prometheus.Labels{"stream": k.Config.StreamARN}).Inc() + if k.metricsLevel != configuration.METRICS_NONE { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamARN, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamARN}).Inc() + } } else { - linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamName, "shard": shardId}).Inc() - linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() + if k.metricsLevel != configuration.METRICS_NONE { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamName, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() + } } var data []CloudwatchSubscriptionLogEvent var err error if k.Config.FromSubscription { - //The AWS docs says that the data is base64 encoded - //but apparently GetRecords decodes it for us ? + // The AWS docs says that the data is base64 encoded + // but apparently GetRecords decodes it for us ? data, err = k.decodeFromSubscription(record.Data) if err != nil { logger.Errorf("Cannot decode data: %s", err) @@ -327,10 +334,10 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan } func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEventStreamReader, out chan types.Event, shardId string, streamName string) error { - logger := k.logger.WithFields(log.Fields{"shard_id": shardId}) - //ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately - //and we won't be able to start a new one if this is the first one started by the tomb - //TODO: look into parent shards to see if a shard is closed before starting to read it ? + logger := k.logger.WithField("shard_id", shardId) + // ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately + // and we won't be able to start a new one if this is the first one started by the tomb + // TODO: look into parent shards to see if a shard is closed before starting to read it ? time.Sleep(time.Second) for { select { @@ -390,7 +397,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { return fmt.Errorf("resource part of stream ARN %s does not start with stream/", k.Config.StreamARN) } - k.logger = k.logger.WithFields(log.Fields{"stream": parsedARN.Resource[7:]}) + k.logger = k.logger.WithField("stream", parsedARN.Resource[7:]) k.logger.Info("starting kinesis acquisition with enhanced fan-out") err = k.DeregisterConsumer() if err != nil { @@ -413,7 +420,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { case <-t.Dying(): k.logger.Infof("Kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) @@ -424,7 +431,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { if k.shardReaderTomb.Err() != nil { return k.shardReaderTomb.Err() } - //All goroutines have exited without error, so a resharding event, start again + // All goroutines have exited without error, so a resharding event, start again k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") continue } @@ -432,17 +439,19 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { } func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { - logger := k.logger.WithFields(log.Fields{"shard": shardId}) + logger := k.logger.WithField("shard", shardId) logger.Debugf("Starting to read shard") - sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ShardId: aws.String(shardId), + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardId), StreamName: &k.Config.StreamName, - ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest)}) + ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest), + }) if err != nil { logger.Errorf("Cannot get shard iterator: %s", err) return fmt.Errorf("cannot get shard iterator: %w", err) } it := sharIt.ShardIterator - //AWS recommends to wait for a second between calls to GetRecords for a given shard + // AWS recommends to wait for a second between calls to GetRecords for a given shard ticker := time.NewTicker(time.Second) for { select { @@ -453,7 +462,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro switch err.(type) { case *kinesis.ProvisionedThroughputExceededException: logger.Warn("Provisioned throughput exceeded") - //TODO: implement exponential backoff + // TODO: implement exponential backoff continue case *kinesis.ExpiredIteratorException: logger.Warn("Expired iterator") @@ -478,7 +487,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro } func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error { - k.logger = k.logger.WithFields(log.Fields{"stream": k.Config.StreamName}) + k.logger = k.logger.WithField("stream", k.Config.StreamName) k.logger.Info("starting kinesis acquisition from shards") for { shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{ @@ -499,7 +508,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error case <-t.Dying(): k.logger.Info("kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves return nil case <-k.shardReaderTomb.Dying(): reason := k.shardReaderTomb.Err() @@ -513,14 +522,13 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error } } -func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") if k.Config.UseEnhancedFanOut { return k.EnhancedRead(out, t) - } else { - return k.ReadFromStream(out, t) } + return k.ReadFromStream(out, t) }) return nil } diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 662d6040e0f..027cbde9240 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "fmt" "net" @@ -12,15 +13,17 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func getLocalStackEndpoint() (string, error) { @@ -29,7 +32,7 @@ func getLocalStackEndpoint() (string, error) { v = strings.TrimPrefix(v, "http://") _, err := net.Dial("tcp", v) if err != nil { - return "", fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err) + return "", fmt.Errorf("while dialing %s: %w: aws endpoint isn't available", v, err) } } return endpoint, nil @@ -58,8 +61,8 @@ func GenSubObject(i int) []byte { gz := gzip.NewWriter(&b) gz.Write(body) gz.Close() - //AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point - //localstack does not do it, so let's just write a raw gzipped stream + // AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point + // localstack does not do it, so let's just write a raw gzipped stream return b.Bytes() } @@ -70,7 +73,7 @@ func WriteToStream(streamName string, count int, shards int, sub bool) { } sess := session.Must(session.NewSession()) kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1")) - for i := 0; i < count; i++ { + for i := range count { partition := "partition" if shards != 1 { partition = fmt.Sprintf("partition-%d", i%shards) @@ -97,10 +100,10 @@ func TestMain(m *testing.M) { os.Setenv("AWS_ACCESS_KEY_ID", "foobar") os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar") - //delete_streams() - //create_streams() + // delete_streams() + // create_streams() code := m.Run() - //delete_streams() + // delete_streams() os.Exit(code) } @@ -138,17 +141,16 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "kinesis", - }) + subLogger := log.WithField("type", "kinesis") for _, test := range tests { f := KinesisSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } func TestReadFromStream(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -170,22 +172,20 @@ stream_name: stream-1-shard`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) - for i := 0; i < test.count; i++ { + for i := range test.count { e := <-out assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) } @@ -195,6 +195,7 @@ stream_name: stream-1-shard`, } func TestReadFromMultipleShards(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -216,23 +217,21 @@ stream_name: stream-2-shards`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) c := 0 - for i := 0; i < test.count; i++ { + for range test.count { <-out c += 1 } @@ -243,6 +242,7 @@ stream_name: stream-2-shards`, } func TestFromSubscription(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -265,22 +265,20 @@ from_subscription: true`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis"), configuration.METRICS_NONE) if err != nil { t.Fatalf("Error configuring source: %s", err) } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, true) - for i := 0; i < test.count; i++ { + for i := range test.count { e := <-out assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) } @@ -311,9 +309,7 @@ use_enhanced_fanout: true`, for _, test := range tests { f := KinesisSource{} config := fmt.Sprintf(test.config, endpoint) - err := f.Configure([]byte(config), log.WithFields(log.Fields{ - "type": "kinesis", - })) + err := f.Configure([]byte(config), log.WithField("type", "kinesis")) if err != nil { t.Fatalf("Error configuring source: %s", err) } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index ee44bd01ae2..f979b044dcc 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -3,6 +3,7 @@ package kubernetesauditacquisition import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -28,12 +29,13 @@ type KubernetesAuditConfiguration struct { } type KubernetesAuditSource struct { - config KubernetesAuditConfiguration - logger *log.Entry - mux *http.ServeMux - server *http.Server - outChan chan types.Event - addr string + metricsLevel int + config KubernetesAuditConfiguration + logger *log.Entry + mux *http.ServeMux + server *http.Server + outChan chan types.Event + addr string } var eventCount = prometheus.NewCounterVec( @@ -72,15 +74,15 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { ka.config = k8sConfig if ka.config.ListenAddr == "" { - return fmt.Errorf("listen_addr cannot be empty") + return errors.New("listen_addr cannot be empty") } if ka.config.ListenPort == 0 { - return fmt.Errorf("listen_port cannot be empty") + return errors.New("listen_port cannot be empty") } if ka.config.WebhookPath == "" { - return fmt.Errorf("webhook_path cannot be empty") + return errors.New("webhook_path cannot be empty") } if ka.config.WebhookPath[0] != '/' { @@ -93,8 +95,9 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry) error { +func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry, MetricsLevel int) error { ka.logger = logger + ka.metricsLevel = MetricsLevel err := ka.UnmarshalConfig(config) if err != nil { @@ -117,7 +120,7 @@ func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry) err } func (ka *KubernetesAuditSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { - return fmt.Errorf("k8s-audit datasource does not support command-line acquisition") + return errors.New("k8s-audit datasource does not support command-line acquisition") } func (ka *KubernetesAuditSource) GetMode() string { @@ -129,10 +132,10 @@ func (ka *KubernetesAuditSource) GetName() string { } func (ka *KubernetesAuditSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("k8s-audit datasource does not support one-shot acquisition") + return errors.New("k8s-audit datasource does not support one-shot acquisition") } -func (ka *KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") @@ -161,7 +164,9 @@ func (ka *KubernetesAuditSource) Dump() interface{} { } func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.Request) { - requestCount.WithLabelValues(ka.addr).Inc() + if ka.metricsLevel != configuration.METRICS_NONE { + requestCount.WithLabelValues(ka.addr).Inc() + } if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return @@ -185,10 +190,12 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R remoteIP := strings.Split(r.RemoteAddr, ":")[0] for _, auditEvent := range auditEvents.Items { - eventCount.WithLabelValues(ka.addr).Inc() + if ka.metricsLevel != configuration.METRICS_NONE { + eventCount.WithLabelValues(ka.addr).Inc() + } bytesEvent, err := json.Marshal(auditEvent) if err != nil { - ka.logger.Errorf("Error marshaling audit event: %s", err) + ka.logger.Errorf("Error serializing audit event: %s", err) continue } ka.logger.Tracef("Got audit event: %s", string(bytesEvent)) diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index c3502c95685..a086a756e4a 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,16 +1,19 @@ package kubernetesauditacquisition import ( + "context" "net/http/httptest" "strings" "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { @@ -45,12 +48,12 @@ listen_addr: 0.0.0.0`, err := f.UnmarshalConfig([]byte(test.config)) assert.Contains(t, err.Error(), test.expectedErr) - }) } } func TestInvalidConfig(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -66,9 +69,7 @@ webhook_path: /k8s-audit`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "k8s-audit", - }) + subLogger := log.WithField("type", "k8s-audit") for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -81,10 +82,10 @@ webhook_path: /k8s-audit`, require.NoError(t, err) - err = f.Configure([]byte(test.config), subLogger) + err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) @@ -99,6 +100,7 @@ webhook_path: /k8s-audit`, } func TestHandler(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -229,9 +231,7 @@ webhook_path: /k8s-audit`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "k8s-audit", - }) + subLogger := log.WithField("type", "k8s-audit") for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -253,21 +253,21 @@ webhook_path: /k8s-audit`, f := KubernetesAuditSource{} err := f.UnmarshalConfig([]byte(test.config)) require.NoError(t, err) - err = f.Configure([]byte(test.config), subLogger) + err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - //time.Sleep(1 * time.Second) + // time.Sleep(1 * time.Second) require.NoError(t, err) tb.Kill(nil) diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 8451a86fcdf..846e833abea 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -11,11 +12,11 @@ import ( "strconv" "time" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/gorilla/websocket" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) type LokiClient struct { @@ -25,6 +26,7 @@ type LokiClient struct { t *tomb.Tomb fail_start time.Time currentTickerInterval time.Duration + requestHeaders map[string]string } type Config struct { @@ -73,6 +75,7 @@ func (lc *LokiClient) resetFailStart() { } lc.fail_start = time.Time{} } + func (lc *LokiClient) shouldRetry() bool { if lc.fail_start.IsZero() { lc.Logger.Warningf("loki is not available, will retry for %s", lc.config.FailMaxDuration) @@ -105,7 +108,7 @@ func (lc *LokiClient) decreaseTicker(ticker *time.Ticker) { } } -func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQueryRangeResponse, infinite bool) error { +func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQueryRangeResponse, infinite bool) error { lc.currentTickerInterval = 100 * time.Millisecond ticker := time.NewTicker(lc.currentTickerInterval) defer ticker.Stop() @@ -116,36 +119,34 @@ func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQu case <-lc.t.Dying(): return lc.t.Err() case <-ticker.C: - resp, err := http.Get(uri) + resp, err := lc.Get(uri) if err != nil { if ok := lc.shouldRetry(); !ok { - return errors.Wrapf(err, "error querying range") - } else { - lc.increaseTicker(ticker) - continue + return fmt.Errorf("error querying range: %w", err) } + lc.increaseTicker(ticker) + continue } if resp.StatusCode != http.StatusOK { + lc.Logger.Warnf("bad HTTP response code for query range: %d", resp.StatusCode) body, _ := io.ReadAll(resp.Body) resp.Body.Close() if ok := lc.shouldRetry(); !ok { - return errors.Wrapf(err, "bad HTTP response code: %d: %s", resp.StatusCode, string(body)) - } else { - lc.increaseTicker(ticker) - continue + return fmt.Errorf("bad HTTP response code: %d: %s: %w", resp.StatusCode, string(body), err) } + lc.increaseTicker(ticker) + continue } var lq LokiQueryRangeResponse if err := json.NewDecoder(resp.Body).Decode(&lq); err != nil { resp.Body.Close() if ok := lc.shouldRetry(); !ok { - return errors.Wrapf(err, "error decoding Loki response") - } else { - lc.increaseTicker(ticker) - continue + return fmt.Errorf("error decoding Loki response: %w", err) } + lc.increaseTicker(ticker) + continue } resp.Body.Close() lc.Logger.Tracef("Got response: %+v", lq) @@ -186,7 +187,6 @@ func (lc *LokiClient) getURLFor(endpoint string, params map[string]string) strin u.RawQuery = queryParams.Encode() u.Path, err = url.JoinPath(lc.config.LokiPrefix, u.Path, endpoint) - if err != nil { return "" } @@ -215,7 +215,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error { return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") - resp, err := http.Get(url) + resp, err := lc.Get(url) if err != nil { lc.Logger.Warnf("Error checking if Loki is ready: %s", err) continue @@ -251,23 +251,22 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { } requestHeader := http.Header{} - for k, v := range lc.config.Headers { + for k, v := range lc.requestHeaders { requestHeader.Add(k, v) } - requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr()) lc.Logger.Infof("Connecting to %s", u) - conn, _, err := dialer.Dial(u, requestHeader) + conn, _, err := dialer.Dial(u, requestHeader) if err != nil { lc.Logger.Errorf("Error connecting to websocket, err: %s", err) - return responseChan, fmt.Errorf("error connecting to websocket") + return responseChan, errors.New("error connecting to websocket") } lc.t.Go(func() error { for { jsonResponse := &LokiResponse{} - err = conn.ReadJSON(jsonResponse) + err = conn.ReadJSON(jsonResponse) if err != nil { lc.Logger.Errorf("Error reading from websocket: %s", err) return fmt.Errorf("websocket error: %w", err) @@ -293,23 +292,33 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) - requestHeader := http.Header{} - for k, v := range lc.config.Headers { - requestHeader.Add(k, v) - } - - if lc.config.Username != "" || lc.config.Password != "" { - requestHeader.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(lc.config.Username+":"+lc.config.Password))) - } - - requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr()) lc.Logger.Infof("Connecting to %s", url) lc.t.Go(func() error { - return lc.queryRange(url, ctx, c, infinite) + return lc.queryRange(ctx, url, c, infinite) }) return c } +// Create a wrapper for http.Get to be able to set headers and auth +func (lc *LokiClient) Get(url string) (*http.Response, error) { + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + for k, v := range lc.requestHeaders { + request.Header.Add(k, v) + } + return http.DefaultClient.Do(request) +} + func NewLokiClient(config Config) *LokiClient { - return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config} + headers := make(map[string]string) + for k, v := range config.Headers { + headers[k] = v + } + if config.Username != "" || config.Password != "" { + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(config.Username+":"+config.Password)) + } + headers["User-Agent"] = useragent.Default() + return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config, requestHeaders: headers} } diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go index 555deefe25a..f867feeb84b 100644 --- a/pkg/acquisition/modules/loki/loki.go +++ b/pkg/acquisition/modules/loki/loki.go @@ -6,20 +6,20 @@ https://grafana.com/docs/loki/latest/api/#get-lokiapiv1tail import ( "context" + "errors" "fmt" "net/url" "strconv" "strings" "time" - "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" tomb "gopkg.in/tomb.v2" yaml "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" - lokiclient "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki/internal/lokiclient" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki/internal/lokiclient" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -57,7 +57,8 @@ type LokiConfiguration struct { } type LokiSource struct { - Config LokiConfiguration + metricsLevel int + Config LokiConfiguration Client *lokiclient.LokiClient @@ -118,9 +119,10 @@ func (l *LokiSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (l *LokiSource) Configure(config []byte, logger *log.Entry) error { +func (l *LokiSource) Configure(config []byte, logger *log.Entry, MetricsLevel int) error { l.Config = LokiConfiguration{} l.logger = logger + l.metricsLevel = MetricsLevel err := l.UnmarshalConfig(config) if err != nil { return err @@ -302,7 +304,9 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri ll.Process = true ll.Module = l.GetName() - linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc() + if l.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc() + } expectMode := types.LIVE if l.Config.UseTimeMachine { expectMode = types.TIMEMACHINE @@ -315,9 +319,9 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri } } -func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) + readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) defer cancel() err := l.Client.Ready(readyCtx) if err != nil { @@ -325,7 +329,7 @@ func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er } ll := l.logger.WithField("websocket_url", l.lokiWebsocket) t.Go(func() error { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() respChan := l.Client.QueryRange(ctx, true) if err != nil { diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index fae2e3aa98f..627200217f5 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -2,6 +2,7 @@ package loki_test import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -13,19 +14,18 @@ import ( "testing" "time" - "context" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + tomb "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" "github.com/crowdsecurity/crowdsec/pkg/types" - log "github.com/sirupsen/logrus" - tomb "gopkg.in/tomb.v2" - "gotest.tools/v3/assert" ) func TestConfiguration(t *testing.T) { - log.Infof("Test 'TestConfigure'") tests := []struct { @@ -95,7 +95,6 @@ query: > delayFor: 1 * time.Second, }, { - config: ` mode: tail source: loki @@ -111,7 +110,6 @@ query: > testName: "Correct config with password", }, { - config: ` mode: tail source: loki @@ -124,25 +122,27 @@ query: > testName: "Invalid DelayFor", }, } - subLogger := log.WithFields(log.Fields{ - "type": "loki", - }) + subLogger := log.WithField("type", "loki") + for _, test := range tests { t.Run(test.testName, func(t *testing.T) { lokiSource := loki.LokiSource{} - err := lokiSource.Configure([]byte(test.config), subLogger) + err := lokiSource.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) + if test.password != "" { p := lokiSource.Config.Auth.Password if test.password != p { t.Fatalf("Password mismatch : %s != %s", test.password, p) } } + if test.waitForReady != 0 { if lokiSource.Config.WaitForReady != test.waitForReady { t.Fatalf("Wrong WaitForReady %v != %v", lokiSource.Config.WaitForReady, test.waitForReady) } } + if test.delayFor != 0 { if lokiSource.Config.DelayFor != test.delayFor { t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) @@ -154,6 +154,7 @@ query: > func TestConfigureDSN(t *testing.T) { log.Infof("Test 'TestConfigureDSN'") + tests := []struct { name string dsn string @@ -218,7 +219,9 @@ func TestConfigureDSN(t *testing.T) { "type": "loki", "name": test.name, }) + t.Logf("Test : %s", test.name) + lokiSource := &loki.LokiSource{} err := lokiSource.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") cstest.AssertErrorContains(t, err, test.expectedErr) @@ -234,17 +237,20 @@ func TestConfigureDSN(t *testing.T) { t.Fatalf("Password mismatch : %s != %s", test.password, p) } } + if test.scheme != "" { url, _ := url.Parse(lokiSource.Config.URL) if test.scheme != url.Scheme { t.Fatalf("Schema mismatch : %s != %s", test.scheme, url.Scheme) } } + if test.waitForReady != 0 { if lokiSource.Config.WaitForReady != test.waitForReady { t.Fatalf("Wrong WaitForReady %v != %v", lokiSource.Config.WaitForReady, test.waitForReady) } } + if test.delayFor != 0 { if lokiSource.Config.DelayFor != test.delayFor { t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) @@ -253,7 +259,7 @@ func TestConfigureDSN(t *testing.T) { } } -func feedLoki(logger *log.Entry, n int, title string) error { +func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error { streams := LogStreams{ Streams: []LogStream{ { @@ -266,26 +272,42 @@ func feedLoki(logger *log.Entry, n int, title string) error { }, }, } - for i := 0; i < n; i++ { + for i := range n { streams.Streams[0].Values[i] = LogValue{ Time: time.Now(), Line: fmt.Sprintf("Log line #%d %v", i, title), } } + buff, err := json.Marshal(streams) if err != nil { return err } - resp, err := http.Post("http://127.0.0.1:3100/loki/api/v1/push", "application/json", bytes.NewBuffer(buff)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) if err != nil { return err } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Scope-Orgid", "1234") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { b, _ := io.ReadAll(resp.Body) logger.Error(string(b)) + return fmt.Errorf("Bad post status %d", resp.StatusCode) } + logger.Info(n, " Events sent") + return nil } @@ -293,9 +315,11 @@ func TestOneShotAcquisition(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") + title := time.Now().String() // Loki will be messy, with a lot of stuff, lets use a unique key tests := []struct { config string @@ -306,6 +330,8 @@ mode: cat source: loki url: http://127.0.0.1:3100 query: '{server="demo",key="%s"}' +headers: + x-scope-orgid: "1234" since: 1h `, title), }, @@ -313,35 +339,39 @@ since: 1h for _, ts := range tests { logger := log.New() - subLogger := logger.WithFields(log.Fields{ - "type": "loki", - }) + subLogger := logger.WithField("type", "loki") lokiSource := loki.LokiSource{} - err := lokiSource.Configure([]byte(ts.config), subLogger) + err := lokiSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = feedLoki(subLogger, 20, title) + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 20, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } out := make(chan types.Event) read := 0 + go func() { for { <-out + read++ } }() + lokiTomb := tomb.Tomb{} + err = lokiSource.OneShotAcquisition(out, &lokiTomb) if err != nil { t.Fatalf("Unexpected error : %s", err) } - assert.Equal(t, 20, read) + assert.Equal(t, 20, read) } } @@ -349,9 +379,11 @@ func TestStreamingAcquisition(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") + title := time.Now().String() tests := []struct { name string @@ -362,31 +394,34 @@ func TestStreamingAcquisition(t *testing.T) { }{ { name: "Bad port", - config: ` -mode: tail + config: `mode: tail source: loki -url: http://127.0.0.1:3101 +url: "http://127.0.0.1:3101" +headers: + x-scope-orgid: "1234" query: > - {server="demo"} -`, // No Loki server here + {server="demo"}`, // No Loki server here expectedErr: "", streamErr: `loki is not ready: context deadline exceeded`, expectedLines: 0, }, { name: "ok", - config: ` -mode: tail + config: `mode: tail source: loki -url: http://127.0.0.1:3100 +url: "http://127.0.0.1:3100" +headers: + x-scope-orgid: "1234" query: > - {server="demo"} -`, + {server="demo"}`, expectedErr: "", streamErr: "", expectedLines: 20, }, } + + ctx := context.Background() + for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { logger := log.New() @@ -398,33 +433,39 @@ query: > out := make(chan types.Event) lokiTomb := tomb.Tomb{} lokiSource := loki.LokiSource{} - err := lokiSource.Configure([]byte(ts.config), subLogger) + + err := lokiSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = lokiSource.StreamingAcquisition(out, &lokiTomb) + + err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) cstest.AssertErrorContains(t, err, ts.streamErr) if ts.streamErr != "" { return } - time.Sleep(time.Second * 2) //We need to give time to start reading from the WS + time.Sleep(time.Second * 2) // We need to give time to start reading from the WS + readTomb := tomb.Tomb{} - readCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) + readCtx, cancel := context.WithTimeout(ctx, time.Second*10) count := 0 readTomb.Go(func() error { defer cancel() + for { select { case <-readCtx.Done(): return readCtx.Err() case evt := <-out: count++ + if !strings.HasSuffix(evt.Line.Raw, title) { return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) } + if count == ts.expectedLines { return nil } @@ -432,57 +473,67 @@ query: > } }) - err = feedLoki(subLogger, ts.expectedLines, title) + err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } err = readTomb.Wait() + cancel() + if err != nil { t.Fatalf("Unexpected error : %s", err) } - assert.Equal(t, count, ts.expectedLines) + + assert.Equal(t, ts.expectedLines, count) }) } - } func TestStopStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + config := ` mode: tail source: loki url: http://127.0.0.1:3100 +headers: + x-scope-orgid: "1234" query: > {server="demo"} ` logger := log.New() - subLogger := logger.WithFields(log.Fields{ - "type": "loki", - }) + subLogger := logger.WithField("type", "loki") title := time.Now().String() lokiSource := loki.LokiSource{} - err := lokiSource.Configure([]byte(config), subLogger) + + err := lokiSource.Configure([]byte(config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("Unexpected error : %s", err) } + out := make(chan types.Event) lokiTomb := &tomb.Tomb{} - err = lokiSource.StreamingAcquisition(out, lokiTomb) + + err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) if err != nil { t.Fatalf("Unexpected error : %s", err) } + time.Sleep(time.Second * 2) - err = feedLoki(subLogger, 1, title) + + err = feedLoki(ctx, subLogger, 1, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } lokiTomb.Kill(nil) + err = lokiTomb.Wait() if err != nil { t.Fatalf("Unexpected error : %s", err) @@ -508,5 +559,6 @@ func (l *LogValue) MarshalJSON() ([]byte, error) { if err != nil { return nil, err } + return []byte(fmt.Sprintf(`["%d",%s]`, l.Time.UnixNano(), string(line))), nil } diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index 651d40d3d50..ed1964edebf 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -38,7 +38,7 @@ type S3Configuration struct { AwsEndpoint string `yaml:"aws_endpoint"` BucketName string `yaml:"bucket_name"` Prefix string `yaml:"prefix"` - Key string `yaml:"-"` //Only for DSN acquisition + Key string `yaml:"-"` // Only for DSN acquisition PollingMethod string `yaml:"polling_method"` PollingInterval int `yaml:"polling_interval"` SQSName string `yaml:"sqs_name"` @@ -47,15 +47,16 @@ type S3Configuration struct { } type S3Source struct { - Config S3Configuration - logger *log.Entry - s3Client s3iface.S3API - sqsClient sqsiface.SQSAPI - readerChan chan S3Object - t *tomb.Tomb - out chan types.Event - ctx aws.Context - cancel context.CancelFunc + MetricsLevel int + Config S3Configuration + logger *log.Entry + s3Client s3iface.S3API + sqsClient sqsiface.SQSAPI + readerChan chan S3Object + t *tomb.Tomb + out chan types.Event + ctx aws.Context + cancel context.CancelFunc } type S3Object struct { @@ -92,10 +93,12 @@ type S3Event struct { } `json:"detail"` } -const PollMethodList = "list" -const PollMethodSQS = "sqs" -const SQSFormatEventBridge = "eventbridge" -const SQSFormatS3Notification = "s3notification" +const ( + PollMethodList = "list" + PollMethodSQS = "sqs" + SQSFormatEventBridge = "eventbridge" + SQSFormatS3Notification = "s3notification" +) var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -130,7 +133,6 @@ func (s *S3Source) newS3Client() error { } sess, err := session.NewSessionWithOptions(options) - if err != nil { return fmt.Errorf("failed to create aws session: %w", err) } @@ -145,7 +147,7 @@ func (s *S3Source) newS3Client() error { s.s3Client = s3.New(sess, config) if s.s3Client == nil { - return fmt.Errorf("failed to create S3 client") + return errors.New("failed to create S3 client") } return nil @@ -166,7 +168,7 @@ func (s *S3Source) newSQSClient() error { } if sess == nil { - return fmt.Errorf("failed to create aws session") + return errors.New("failed to create aws session") } config := aws.NewConfig() if s.Config.AwsRegion != "" { @@ -177,7 +179,7 @@ func (s *S3Source) newSQSClient() error { } s.sqsClient = sqs.New(sess, config) if s.sqsClient == nil { - return fmt.Errorf("failed to create SQS client") + return errors.New("failed to create SQS client") } return nil } @@ -204,7 +206,7 @@ func (s *S3Source) getBucketContent() ([]*s3.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content for %s", s.Config.BucketName) bucketObjects := make([]*s3.Object, 0) - var continuationToken *string = nil + var continuationToken *string for { out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), @@ -250,16 +252,15 @@ func (s *S3Source) listPoll() error { continue } for i := len(bucketObjects) - 1; i >= 0; i-- { - if bucketObjects[i].LastModified.After(lastObjectDate) { - newObject = true - logger.Debugf("Found new object %s", *bucketObjects[i].Key) - s.readerChan <- S3Object{ - Bucket: s.Config.BucketName, - Key: *bucketObjects[i].Key, - } - } else { + if !bucketObjects[i].LastModified.After(lastObjectDate) { break } + newObject = true + logger.Debugf("Found new object %s", *bucketObjects[i].Key) + s.readerChan <- S3Object{ + Bucket: s.Config.BucketName, + Key: *bucketObjects[i].Key, + } } if newObject { lastObjectDate = *bucketObjects[len(bucketObjects)-1].LastModified @@ -277,7 +278,7 @@ func extractBucketAndPrefixFromEventBridge(message *string) (string, string, err if eventBody.Detail.Bucket.Name != "" { return eventBody.Detail.Bucket.Name, eventBody.Detail.Object.Key, nil } - return "", "", fmt.Errorf("invalid event body for event bridge format") + return "", "", errors.New("invalid event body for event bridge format") } func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) { @@ -287,7 +288,7 @@ func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) return "", "", err } if len(s3notifBody.Records) == 0 { - return "", "", fmt.Errorf("no records found in S3 notification") + return "", "", errors.New("no records found in S3 notification") } if !strings.HasPrefix(s3notifBody.Records[0].EventName, "ObjectCreated:") { return "", "", fmt.Errorf("event %s is not supported", s3notifBody.Records[0].EventName) @@ -296,19 +297,20 @@ func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) } func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, error) { - if s.Config.SQSFormat == SQSFormatEventBridge { + switch s.Config.SQSFormat { + case SQSFormatEventBridge: bucket, key, err := extractBucketAndPrefixFromEventBridge(message) if err != nil { return "", "", err } return bucket, key, nil - } else if s.Config.SQSFormat == SQSFormatS3Notification { + case SQSFormatS3Notification: bucket, key, err := extractBucketAndPrefixFromS3Notif(message) if err != nil { return "", "", err } return bucket, key, nil - } else { + default: bucket, key, err := extractBucketAndPrefixFromEventBridge(message) if err == nil { s.Config.SQSFormat = SQSFormatEventBridge @@ -319,7 +321,7 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro s.Config.SQSFormat = SQSFormatS3Notification return bucket, key, nil } - return "", "", fmt.Errorf("SQS message format not supported") + return "", "", errors.New("SQS message format not supported") } } @@ -336,7 +338,7 @@ func (s *S3Source) sqsPoll() error { out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: aws.Int64(10), - WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ? + WaitTimeSeconds: aws.Int64(20), // Probably no need to make it configurable ? }) if err != nil { logger.Errorf("Error while polling SQS: %s", err) @@ -345,11 +347,13 @@ func (s *S3Source) sqsPoll() error { logger.Tracef("SQS output: %v", out) logger.Debugf("Received %d messages from SQS", len(out.Messages)) for _, message := range out.Messages { - sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc() + } bucket, key, err := s.extractBucketAndPrefix(message.Body) if err != nil { logger.Errorf("Error while parsing SQS message: %s", err) - //Always delete the message to avoid infinite loop + // Always delete the message to avoid infinite loop _, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -375,7 +379,7 @@ func (s *S3Source) sqsPoll() error { } func (s *S3Source) readFile(bucket string, key string) error { - //TODO: Handle SSE-C + // TODO: Handle SSE-C var scanner *bufio.Scanner logger := s.logger.WithFields(log.Fields{ @@ -388,14 +392,13 @@ func (s *S3Source) readFile(bucket string, key string) error { Bucket: aws.String(bucket), Key: aws.String(key), }) - if err != nil { return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err) } defer output.Body.Close() if strings.HasSuffix(key, ".gz") { - //This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) + // This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) header := make([]byte, 2) _, err := output.Body.Read(header) if err != nil { @@ -426,14 +429,20 @@ func (s *S3Source) readFile(bucket string, key string) error { default: text := scanner.Text() logger.Tracef("Read line %s", text) - linesRead.WithLabelValues(bucket).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + linesRead.WithLabelValues(bucket).Inc() + } l := types.Line{} l.Raw = text l.Labels = s.Config.Labels l.Time = time.Now().UTC() l.Process = true l.Module = s.GetName() - l.Src = bucket + "/" + key + if s.MetricsLevel == configuration.METRICS_FULL { + l.Src = bucket + "/" + key + } else if s.MetricsLevel == configuration.METRICS_AGGREGATE { + l.Src = bucket + } var evt types.Event if !s.Config.UseTimeMachine { evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} @@ -446,7 +455,9 @@ func (s *S3Source) readFile(bucket string, key string) error { if err := scanner.Err(); err != nil { return fmt.Errorf("failed to read object %s/%s: %s", bucket, key, err) } - objectsRead.WithLabelValues(bucket).Inc() + if s.MetricsLevel != configuration.METRICS_NONE { + objectsRead.WithLabelValues(bucket).Inc() + } return nil } @@ -457,6 +468,7 @@ func (s *S3Source) GetUuid() string { func (s *S3Source) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } + func (s *S3Source) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } @@ -487,15 +499,15 @@ func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error { } if s.Config.BucketName != "" && s.Config.SQSName != "" { - return fmt.Errorf("bucket_name and sqs_name are mutually exclusive") + return errors.New("bucket_name and sqs_name are mutually exclusive") } if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" { - return fmt.Errorf("sqs_name is required when using sqs polling method") + return errors.New("sqs_name is required when using sqs polling method") } if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList { - return fmt.Errorf("bucket_name is required") + return errors.New("bucket_name is required") } if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification { @@ -505,7 +517,7 @@ func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry) error { +func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry, metricsLevel int) error { err := s.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -557,11 +569,11 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * }) dsn = strings.TrimPrefix(dsn, "s3://") args := strings.Split(dsn, "?") - if len(args[0]) == 0 { - return fmt.Errorf("empty s3:// DSN") + if args[0] == "" { + return errors.New("empty s3:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse s3 args: %w", err) @@ -600,7 +612,7 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * pathParts := strings.Split(args[0], "/") s.logger.Debugf("pathParts: %v", pathParts) - //FIXME: handle s3://bucket/ + // FIXME: handle s3://bucket/ if len(pathParts) == 1 { s.Config.BucketName = pathParts[0] s.Config.Prefix = "" @@ -643,7 +655,7 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return err } } else { - //No key, get everything in the bucket based on the prefix + // No key, get everything in the bucket based on the prefix objects, err := s.getBucketContent() if err != nil { return err @@ -659,11 +671,11 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return nil } -func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { s.t = t s.out = out - s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered? + s.ctx, s.cancel = context.WithCancel(ctx) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { s.readManager() diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 02423b1392c..05a974517a0 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -14,10 +14,12 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestBadConfiguration(t *testing.T) { @@ -66,7 +68,7 @@ sqs_name: foobar for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := S3Source{} - err := f.Configure([]byte(test.config), nil) + err := f.Configure([]byte(test.config), nil, configuration.METRICS_NONE) if err == nil { t.Fatalf("expected error, got none") } @@ -111,7 +113,7 @@ polling_method: list t.Run(test.name, func(t *testing.T) { f := S3Source{} logger := log.NewEntry(log.New()) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -265,13 +267,12 @@ func TestDSNAcquis(t *testing.T) { time.Sleep(2 * time.Second) done <- true assert.Equal(t, test.expectedCount, linesRead) - }) } - } func TestListPolling(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -306,7 +307,7 @@ prefix: foo/ f := S3Source{} logger := log.NewEntry(log.New()) logger.Logger.SetLevel(log.TraceLevel) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -331,8 +332,7 @@ prefix: foo/ } }() - err = f.StreamingAcquisition(out, &tb) - + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -349,6 +349,7 @@ prefix: foo/ } func TestSQSPoll(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -381,7 +382,7 @@ sqs_name: test linesRead := 0 f := S3Source{} logger := log.NewEntry(log.New()) - err := f.Configure([]byte(test.config), logger) + err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -412,8 +413,7 @@ sqs_name: test } }() - err = f.StreamingAcquisition(out, &tb) - + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go index 3b59a806b8b..66d842ed519 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go @@ -1,7 +1,7 @@ package rfc3164 import ( - "fmt" + "errors" "time" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog/internal/parser/utils" @@ -52,7 +52,7 @@ func (r *RFC3164) parsePRI() error { pri := 0 if r.buf[r.position] != '<' { - return fmt.Errorf("PRI must start with '<'") + return errors.New("PRI must start with '<'") } r.position++ @@ -64,18 +64,18 @@ func (r *RFC3164) parsePRI() error { break } if c < '0' || c > '9' { - return fmt.Errorf("PRI must be a number") + return errors.New("PRI must be a number") } pri = pri*10 + int(c-'0') r.position++ } if pri > 999 { - return fmt.Errorf("PRI must be up to 3 characters long") + return errors.New("PRI must be up to 3 characters long") } if r.position == r.len && r.buf[r.position-1] != '>' { - return fmt.Errorf("PRI must end with '>'") + return errors.New("PRI must end with '>'") } r.PRI = pri @@ -98,7 +98,7 @@ func (r *RFC3164) parseTimestamp() error { } } if !validTs { - return fmt.Errorf("timestamp is not valid") + return errors.New("timestamp is not valid") } if r.useCurrentYear { if r.Timestamp.Year() == 0 { @@ -122,11 +122,11 @@ func (r *RFC3164) parseHostname() error { } if r.strictHostname { if !utils.IsValidHostnameOrIP(string(hostname)) { - return fmt.Errorf("hostname is not valid") + return errors.New("hostname is not valid") } } if len(hostname) == 0 { - return fmt.Errorf("hostname is empty") + return errors.New("hostname is empty") } r.Hostname = string(hostname) return nil @@ -147,7 +147,7 @@ func (r *RFC3164) parseTag() error { r.position++ } if len(tag) == 0 { - return fmt.Errorf("tag is empty") + return errors.New("tag is empty") } r.Tag = string(tag) @@ -167,7 +167,7 @@ func (r *RFC3164) parseTag() error { break } if c < '0' || c > '9' { - return fmt.Errorf("pid inside tag must be a number") + return errors.New("pid inside tag must be a number") } tmpPid = append(tmpPid, c) r.position++ @@ -175,7 +175,7 @@ func (r *RFC3164) parseTag() error { } if hasPid && !pidEnd { - return fmt.Errorf("pid inside tag must be closed with ']'") + return errors.New("pid inside tag must be closed with ']'") } if hasPid { @@ -191,7 +191,7 @@ func (r *RFC3164) parseMessage() error { } if r.position == r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } c := r.buf[r.position] @@ -202,7 +202,7 @@ func (r *RFC3164) parseMessage() error { for { if r.position >= r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } c := r.buf[r.position] if c != ' ' { @@ -219,7 +219,7 @@ func (r *RFC3164) parseMessage() error { func (r *RFC3164) Parse(message []byte) error { r.len = len(message) if r.len == 0 { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } r.buf = message diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go index 48772d596f4..3af6614bce6 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestPri(t *testing.T) { @@ -22,33 +26,24 @@ func TestPri(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parsePRI() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.PRI != test.expected { - t.Errorf("expected %d, got %d", test.expected, r.PRI) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.PRI) }) } } func TestTimestamp(t *testing.T) { - tests := []struct { input string expected string @@ -64,31 +59,24 @@ func TestTimestamp(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC3164Option{} if test.currentYear { opts = append(opts, WithCurrentYear()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTimestamp() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Timestamp.Format(time.RFC3339) != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Timestamp.Format(time.RFC3339)) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Timestamp.Format(time.RFC3339)) }) } } @@ -118,31 +106,24 @@ func TestHostname(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC3164Option{} if test.strictHostname { opts = append(opts, WithStrictHostname()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseHostname() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Hostname != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Hostname) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Hostname) }) } } @@ -163,32 +144,20 @@ func TestTag(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTag() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else { - if r.Tag != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Tag) - } - if r.PID != test.expectedPID { - t.Errorf("expected %s, got %s", test.expected, r.Message) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Tag) + assert.Equal(t, test.expectedPID, r.PID) }) } } @@ -207,27 +176,19 @@ func TestMessage(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseMessage() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Message != test.expected { - t.Errorf("expected message %s, got %s", test.expected, r.Tag) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Message) }) } } @@ -241,6 +202,7 @@ func TestParse(t *testing.T) { Message string PRI int } + tests := []struct { input string expected expected @@ -329,42 +291,22 @@ func TestParse(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := NewRFC3164Parser(test.opts...) + err := r.Parse([]byte(test.input)) - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error '%s', got '%s'", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: '%s'", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error '%s', got no error", test.expectedErr) - } else { - if r.Timestamp != test.expected.Timestamp { - t.Errorf("expected timestamp '%s', got '%s'", test.expected.Timestamp, r.Timestamp) - } - if r.Hostname != test.expected.Hostname { - t.Errorf("expected hostname '%s', got '%s'", test.expected.Hostname, r.Hostname) - } - if r.Tag != test.expected.Tag { - t.Errorf("expected tag '%s', got '%s'", test.expected.Tag, r.Tag) - } - if r.PID != test.expected.PID { - t.Errorf("expected pid '%s', got '%s'", test.expected.PID, r.PID) - } - if r.Message != test.expected.Message { - t.Errorf("expected message '%s', got '%s'", test.expected.Message, r.Message) - } - if r.PRI != test.expected.PRI { - t.Errorf("expected pri '%d', got '%d'", test.expected.PRI, r.PRI) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected.Timestamp, r.Timestamp) + assert.Equal(t, test.expected.Hostname, r.Hostname) + assert.Equal(t, test.expected.Tag, r.Tag) + assert.Equal(t, test.expected.PID, r.PID) + assert.Equal(t, test.expected.Message, r.Message) + assert.Equal(t, test.expected.PRI, r.PRI) }) } } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go index 42073cafbae..3805090f57f 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/perf_test.go @@ -51,7 +51,6 @@ func BenchmarkParse(b *testing.B) { } var err error for _, test := range tests { - test := test b.Run(string(test.input), func(b *testing.B) { for i := 0; i < b.N; i++ { r := NewRFC3164Parser(test.opts...) diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go index 8b71a77e2e3..639e91e1224 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go @@ -1,7 +1,7 @@ package rfc5424 import ( - "fmt" + "errors" "time" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog/internal/parser/utils" @@ -52,7 +52,7 @@ func (r *RFC5424) parsePRI() error { pri := 0 if r.buf[r.position] != '<' { - return fmt.Errorf("PRI must start with '<'") + return errors.New("PRI must start with '<'") } r.position++ @@ -64,18 +64,18 @@ func (r *RFC5424) parsePRI() error { break } if c < '0' || c > '9' { - return fmt.Errorf("PRI must be a number") + return errors.New("PRI must be a number") } pri = pri*10 + int(c-'0') r.position++ } if pri > 999 { - return fmt.Errorf("PRI must be up to 3 characters long") + return errors.New("PRI must be up to 3 characters long") } if r.position == r.len && r.buf[r.position-1] != '>' { - return fmt.Errorf("PRI must end with '>'") + return errors.New("PRI must end with '>'") } r.PRI = pri @@ -84,11 +84,11 @@ func (r *RFC5424) parsePRI() error { func (r *RFC5424) parseVersion() error { if r.buf[r.position] != '1' { - return fmt.Errorf("version must be 1") + return errors.New("version must be 1") } r.position += 2 if r.position >= r.len { - return fmt.Errorf("version must be followed by a space") + return errors.New("version must be followed by a space") } return nil } @@ -113,17 +113,17 @@ func (r *RFC5424) parseTimestamp() error { } if len(timestamp) == 0 { - return fmt.Errorf("timestamp is empty") + return errors.New("timestamp is empty") } if r.position == r.len { - return fmt.Errorf("EOL after timestamp") + return errors.New("EOL after timestamp") } date, err := time.Parse(VALID_TIMESTAMP, string(timestamp)) if err != nil { - return fmt.Errorf("timestamp is not valid") + return errors.New("timestamp is not valid") } r.Timestamp = date @@ -131,7 +131,7 @@ func (r *RFC5424) parseTimestamp() error { r.position++ if r.position >= r.len { - return fmt.Errorf("EOL after timestamp") + return errors.New("EOL after timestamp") } return nil @@ -156,11 +156,11 @@ func (r *RFC5424) parseHostname() error { } if r.strictHostname { if !utils.IsValidHostnameOrIP(string(hostname)) { - return fmt.Errorf("hostname is not valid") + return errors.New("hostname is not valid") } } if len(hostname) == 0 { - return fmt.Errorf("hostname is empty") + return errors.New("hostname is empty") } r.Hostname = string(hostname) return nil @@ -185,11 +185,11 @@ func (r *RFC5424) parseAppName() error { } if len(appname) == 0 { - return fmt.Errorf("appname is empty") + return errors.New("appname is empty") } if len(appname) > 48 { - return fmt.Errorf("appname is too long") + return errors.New("appname is too long") } r.Tag = string(appname) @@ -215,11 +215,11 @@ func (r *RFC5424) parseProcID() error { } if len(procid) == 0 { - return fmt.Errorf("procid is empty") + return errors.New("procid is empty") } if len(procid) > 128 { - return fmt.Errorf("procid is too long") + return errors.New("procid is too long") } r.PID = string(procid) @@ -245,11 +245,11 @@ func (r *RFC5424) parseMsgID() error { } if len(msgid) == 0 { - return fmt.Errorf("msgid is empty") + return errors.New("msgid is empty") } if len(msgid) > 32 { - return fmt.Errorf("msgid is too long") + return errors.New("msgid is too long") } r.MsgID = string(msgid) @@ -263,7 +263,7 @@ func (r *RFC5424) parseStructuredData() error { return nil } if r.buf[r.position] != '[' { - return fmt.Errorf("structured data must start with '[' or be '-'") + return errors.New("structured data must start with '[' or be '-'") } prev := byte(0) for r.position < r.len { @@ -281,14 +281,14 @@ func (r *RFC5424) parseStructuredData() error { } r.position++ if !done { - return fmt.Errorf("structured data must end with ']'") + return errors.New("structured data must end with ']'") } return nil } func (r *RFC5424) parseMessage() error { if r.position == r.len { - return fmt.Errorf("message is empty") + return errors.New("message is empty") } message := []byte{} @@ -305,7 +305,7 @@ func (r *RFC5424) parseMessage() error { func (r *RFC5424) Parse(message []byte) error { r.len = len(message) if r.len == 0 { - return fmt.Errorf("syslog line is empty") + return errors.New("syslog line is empty") } r.buf = message @@ -315,7 +315,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after PRI") + return errors.New("EOL after PRI") } err = r.parseVersion() @@ -324,7 +324,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after Version") + return errors.New("EOL after Version") } err = r.parseTimestamp() @@ -333,7 +333,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after Timestamp") + return errors.New("EOL after Timestamp") } err = r.parseHostname() @@ -342,7 +342,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after hostname") + return errors.New("EOL after hostname") } err = r.parseAppName() @@ -351,7 +351,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after appname") + return errors.New("EOL after appname") } err = r.parseProcID() @@ -360,7 +360,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after ProcID") + return errors.New("EOL after ProcID") } err = r.parseMsgID() @@ -369,7 +369,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after MSGID") + return errors.New("EOL after MSGID") } err = r.parseStructuredData() @@ -378,7 +378,7 @@ func (r *RFC5424) Parse(message []byte) error { } if r.position >= r.len { - return fmt.Errorf("EOL after SD") + return errors.New("EOL after SD") } err = r.parseMessage() diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go index 66a20d594e4..0938e947fe7 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestPri(t *testing.T) { @@ -25,7 +25,6 @@ func TestPri(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { r := &RFC5424{} r.buf = []byte(test.input) @@ -61,7 +60,6 @@ func TestHostname(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.input, func(t *testing.T) { opts := []RFC5424Option{} if test.strictHostname { @@ -200,7 +198,6 @@ func TestParse(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { r := NewRFC5424Parser(test.opts...) err := r.Parse([]byte(test.input)) diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go index 318571e91ee..a86c17e8ddf 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/perf_test.go @@ -92,7 +92,6 @@ func BenchmarkParse(b *testing.B) { } var err error for _, test := range tests { - test := test b.Run(test.label, func(b *testing.B) { for i := 0; i < b.N; i++ { r := NewRFC5424Parser() diff --git a/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go b/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go index 8fe717a6ab2..5e0bf8fe771 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go +++ b/pkg/acquisition/modules/syslog/internal/parser/utils/utils.go @@ -34,7 +34,7 @@ func isValidHostname(s string) bool { last := byte('.') nonNumeric := false // true once we've seen a letter or hyphen partlen := 0 - for i := 0; i < len(s); i++ { + for i := range len(s) { c := s[i] switch { default: diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go index 8aed2836816..5315096fb9b 100644 --- a/pkg/acquisition/modules/syslog/syslog.go +++ b/pkg/acquisition/modules/syslog/syslog.go @@ -1,6 +1,8 @@ package syslogacquisition import ( + "context" + "errors" "fmt" "net" "strings" @@ -29,10 +31,11 @@ type SyslogConfiguration struct { } type SyslogSource struct { - config SyslogConfiguration - logger *log.Entry - server *syslogserver.SyslogServer - serverTomb *tomb.Tomb + metricsLevel int + config SyslogConfiguration + logger *log.Entry + server *syslogserver.SyslogServer + serverTomb *tomb.Tomb } var linesReceived = prometheus.NewCounterVec( @@ -78,11 +81,11 @@ func (s *SyslogSource) GetAggregMetrics() []prometheus.Collector { } func (s *SyslogSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { - return fmt.Errorf("syslog datasource does not support one shot acquisition") + return errors.New("syslog datasource does not support one shot acquisition") } func (s *SyslogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - return fmt.Errorf("syslog datasource does not support one shot acquisition") + return errors.New("syslog datasource does not support one shot acquisition") } func validatePort(port int) bool { @@ -103,7 +106,7 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error { } if s.config.Addr == "" { - s.config.Addr = "127.0.0.1" //do we want a usable or secure default ? + s.config.Addr = "127.0.0.1" // do we want a usable or secure default ? } if s.config.Port == 0 { s.config.Port = 514 @@ -121,10 +124,10 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { s.logger = logger s.logger.Infof("Starting syslog datasource configuration") - + s.metricsLevel = MetricsLevel err := s.UnmarshalConfig(yamlConfig) if err != nil { return err @@ -133,7 +136,7 @@ func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry) error { return nil } -func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *SyslogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { c := make(chan syslogserver.SyslogMessage) s.server = &syslogserver.SyslogServer{Logger: s.logger.WithField("syslog", "internal"), MaxMessageLen: s.config.MaxMessageLen} s.server.SetChannel(c) @@ -150,7 +153,8 @@ func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, - appname string, pid string, msg string) string { + appname string, pid string, msg string, +) string { ret := "" if !ts.IsZero() { ret += ts.Format("Jan 2 15:04:05") @@ -176,7 +180,6 @@ func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, ret += msg } return ret - } func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c chan syslogserver.SyslogMessage) error { @@ -198,7 +201,9 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha logger := s.logger.WithField("client", syslogLine.Client) logger.Tracef("raw: %s", syslogLine) - linesReceived.With(prometheus.Labels{"source": syslogLine.Client}).Inc() + if s.metricsLevel != configuration.METRICS_NONE { + linesReceived.With(prometheus.Labels{"source": syslogLine.Client}).Inc() + } p := rfc3164.NewRFC3164Parser(rfc3164.WithCurrentYear()) err := p.Parse(syslogLine.Message) if err != nil { @@ -211,10 +216,14 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha continue } line = s.buildLogFromSyslog(p2.Timestamp, p2.Hostname, p2.Tag, p2.PID, p2.Message) - linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc5424"}).Inc() + if s.metricsLevel != configuration.METRICS_NONE { + linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc5424"}).Inc() + } } else { line = s.buildLogFromSyslog(p.Timestamp, p.Hostname, p.Tag, p.PID, p.Message) - linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc3164"}).Inc() + if s.metricsLevel != configuration.METRICS_NONE { + linesParsed.With(prometheus.Labels{"source": syslogLine.Client, "type": "rfc3164"}).Inc() + } } line = strings.TrimSuffix(line, "\n") diff --git a/pkg/acquisition/modules/syslog/syslog_test.go b/pkg/acquisition/modules/syslog/syslog_test.go index 1d2ba3fb648..57fa3e8747b 100644 --- a/pkg/acquisition/modules/syslog/syslog_test.go +++ b/pkg/acquisition/modules/syslog/syslog_test.go @@ -1,19 +1,21 @@ package syslogacquisition import ( + "context" "fmt" "net" "runtime" "testing" "time" - "github.com/crowdsecurity/go-cs-lib/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" - "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestConfigure(t *testing.T) { @@ -51,12 +53,10 @@ listen_addr: 10.0.0`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "syslog", - }) + subLogger := log.WithField("type", "syslog") for _, test := range tests { s := SyslogSource{} - err := s.Configure([]byte(test.config), subLogger) + err := s.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) cstest.AssertErrorContains(t, err, test.expectedErr) } } @@ -81,6 +81,7 @@ func writeToSyslog(logs []string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -101,8 +102,10 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 2, - logs: []string{`<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, - `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`}, + logs: []string{ + `<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, + `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`, + }, }, { name: "RFC3164", @@ -110,10 +113,12 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 3, - logs: []string{`<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, + logs: []string{ + `<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, `<13>May 18 12:37:56 mantis sshd[49340]: blabla2`, `<13>May 18 12:37:56 mantis sshd: blabla2`, - `<13>May 18 12:37:56 mantis sshd`}, + `<13>May 18 12:37:56 mantis sshd`, + }, }, } if runtime.GOOS != "windows" { @@ -131,19 +136,16 @@ listen_addr: 127.0.0.1`, } for _, ts := range tests { - ts := ts t.Run(ts.name, func(t *testing.T) { - subLogger := log.WithFields(log.Fields{ - "type": "syslog", - }) + subLogger := log.WithField("type", "syslog") s := SyslogSource{} - err := s.Configure([]byte(ts.config), subLogger) + err := s.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) if err != nil { t.Fatalf("could not configure syslog source : %s", err) } tomb := tomb.Tomb{} out := make(chan types.Event) - err = s.StreamingAcquisition(out, &tomb) + err = s.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if ts.expectedErr != "" { return diff --git a/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx new file mode 100644 index 00000000000..2c4f8b0f680 Binary files /dev/null and b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx differ diff --git a/pkg/acquisition/modules/wineventlog/wineventlog.go b/pkg/acquisition/modules/wineventlog/wineventlog.go index f0eca5d13d7..6d522d8d8cb 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "errors" "github.com/prometheus/client_golang/prometheus" @@ -23,7 +24,7 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLevel int) error { return nil } @@ -59,7 +60,7 @@ func (w *WinEventLogSource) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { return nil } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go index ee69dc35cdd..ca40363155b 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go @@ -1,10 +1,13 @@ package wineventlogacquisition import ( + "context" "encoding/xml" "errors" "fmt" + "net/url" "runtime" + "strconv" "strings" "syscall" "time" @@ -29,16 +32,17 @@ type WinEventLogConfiguration struct { EventLevel string `yaml:"event_level"` EventIDs []int `yaml:"event_ids"` XPathQuery string `yaml:"xpath_query"` - EventFile string `yaml:"event_file"` + EventFile string PrettyName string `yaml:"pretty_name"` } type WinEventLogSource struct { - config WinEventLogConfiguration - logger *log.Entry - evtConfig *winlog.SubscribeConfig - query string - name string + metricsLevel int + config WinEventLogConfiguration + logger *log.Entry + evtConfig *winlog.SubscribeConfig + query string + name string } type QueryList struct { @@ -46,10 +50,13 @@ type QueryList struct { } type Select struct { - Path string `xml:"Path,attr"` + Path string `xml:"Path,attr,omitempty"` Query string `xml:",chardata"` } +// 0 identifies the local machine in windows APIs +const localMachine = 0 + var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_winevtlogsource_hits_total", @@ -148,7 +155,7 @@ func (w *WinEventLogSource) buildXpathQuery() (string, error) { queryList := QueryList{Select: Select{Path: w.config.EventChannel, Query: query}} xpathQuery, err := xml.Marshal(queryList) if err != nil { - w.logger.Errorf("Marshal failed: %v", err) + w.logger.Errorf("Serialize failed: %v", err) return "", err } w.logger.Debugf("xpathQuery: %s", xpathQuery) @@ -188,7 +195,9 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error continue } for _, event := range renderedEvents { - linesRead.With(prometheus.Labels{"source": w.name}).Inc() + if w.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": w.name}).Inc() + } l := types.Line{} l.Raw = event l.Module = w.GetName() @@ -208,20 +217,28 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error } } -func (w *WinEventLogSource) generateConfig(query string) (*winlog.SubscribeConfig, error) { +func (w *WinEventLogSource) generateConfig(query string, live bool) (*winlog.SubscribeConfig, error) { var config winlog.SubscribeConfig var err error - // Create a subscription signaler. - config.SignalEvent, err = windows.CreateEvent( - nil, // Default security descriptor. - 1, // Manual reset. - 1, // Initial state is signaled. - nil) // Optional name. - if err != nil { - return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + if live { + // Create a subscription signaler. + config.SignalEvent, err = windows.CreateEvent( + nil, // Default security descriptor. + 1, // Manual reset. + 1, // Initial state is signaled. + nil) // Optional name. + if err != nil { + return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + } + config.Flags = wevtapi.EvtSubscribeToFutureEvents + } else { + config.ChannelPath, err = syscall.UTF16PtrFromString(w.config.EventFile) + if err != nil { + return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) + } + config.Flags = wevtapi.EvtQueryFilePath | wevtapi.EvtQueryForwardDirection } - config.Flags = wevtapi.EvtSubscribeToFutureEvents config.Query, err = syscall.UTF16PtrFromString(query) if err != nil { return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) @@ -243,11 +260,11 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { } if w.config.EventChannel != "" && w.config.XPathQuery != "" { - return fmt.Errorf("event_channel and xpath_query are mutually exclusive") + return errors.New("event_channel and xpath_query are mutually exclusive") } if w.config.EventChannel == "" && w.config.XPathQuery == "" { - return fmt.Errorf("event_channel or xpath_query must be set") + return errors.New("event_channel or xpath_query must be set") } w.config.Mode = configuration.TAIL_MODE @@ -270,15 +287,16 @@ func (w *WinEventLogSource) UnmarshalConfig(yamlConfig []byte) error { return nil } -func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) error { +func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { w.logger = logger + w.metricsLevel = MetricsLevel err := w.UnmarshalConfig(yamlConfig) if err != nil { return err } - w.evtConfig, err = w.generateConfig(w.query) + w.evtConfig, err = w.generateConfig(w.query, true) if err != nil { return err } @@ -287,6 +305,78 @@ func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry) erro } func (w *WinEventLogSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { + if !strings.HasPrefix(dsn, "wineventlog://") { + return fmt.Errorf("invalid DSN %s for wineventlog source, must start with wineventlog://", dsn) + } + + w.logger = logger + w.config = WinEventLogConfiguration{} + + dsn = strings.TrimPrefix(dsn, "wineventlog://") + + args := strings.Split(dsn, "?") + + if args[0] == "" { + return errors.New("empty wineventlog:// DSN") + } + + if len(args) > 2 { + return errors.New("too many arguments in DSN") + } + + w.config.EventFile = args[0] + + if len(args) == 2 && args[1] != "" { + params, err := url.ParseQuery(args[1]) + if err != nil { + return fmt.Errorf("failed to parse DSN parameters: %w", err) + } + + for key, value := range params { + switch key { + case "log_level": + if len(value) != 1 { + return errors.New("log_level must be a single value") + } + lvl, err := log.ParseLevel(value[0]) + if err != nil { + return fmt.Errorf("failed to parse log_level: %s", err) + } + w.logger.Logger.SetLevel(lvl) + case "event_id": + for _, id := range value { + evtid, err := strconv.Atoi(id) + if err != nil { + return fmt.Errorf("failed to parse event_id: %s", err) + } + w.config.EventIDs = append(w.config.EventIDs, evtid) + } + case "event_level": + if len(value) != 1 { + return errors.New("event_level must be a single value") + } + w.config.EventLevel = value[0] + } + } + } + + var err error + + //FIXME: handle custom xpath query + w.query, err = w.buildXpathQuery() + + if err != nil { + return fmt.Errorf("buildXpathQuery failed: %w", err) + } + + w.logger.Debugf("query: %s\n", w.query) + + w.evtConfig, err = w.generateConfig(w.query, false) + + if err != nil { + return fmt.Errorf("generateConfig failed: %w", err) + } + return nil } @@ -295,10 +385,57 @@ func (w *WinEventLogSource) GetMode() string { } func (w *WinEventLogSource) SupportedModes() []string { - return []string{configuration.TAIL_MODE} + return []string{configuration.TAIL_MODE, configuration.CAT_MODE} } func (w *WinEventLogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + + handle, err := wevtapi.EvtQuery(localMachine, w.evtConfig.ChannelPath, w.evtConfig.Query, w.evtConfig.Flags) + + if err != nil { + return fmt.Errorf("EvtQuery failed: %v", err) + } + + defer winlog.Close(handle) + + publisherCache := make(map[string]windows.Handle) + defer func() { + for _, h := range publisherCache { + winlog.Close(h) + } + }() + +OUTER_LOOP: + for { + select { + case <-t.Dying(): + w.logger.Infof("wineventlog is dying") + return nil + default: + evts, err := w.getXMLEvents(w.evtConfig, publisherCache, handle, 500) + if err == windows.ERROR_NO_MORE_ITEMS { + log.Info("No more items") + break OUTER_LOOP + } else if err != nil { + return fmt.Errorf("getXMLEvents failed: %v", err) + } + w.logger.Debugf("Got %d events", len(evts)) + for _, evt := range evts { + w.logger.Tracef("Event: %s", evt) + if w.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": w.name}).Inc() + } + l := types.Line{} + l.Raw = evt + l.Module = w.GetName() + l.Labels = w.config.Labels + l.Time = time.Now() + l.Src = w.name + l.Process = true + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + } + } + } return nil } @@ -321,7 +458,7 @@ func (w *WinEventLogSource) CanRun() error { return nil } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/wineventlog/streaming") return w.getEvents(out, t) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go similarity index 65% rename from pkg/acquisition/modules/wineventlog/wineventlog_test.go rename to pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go index 053ba88b52d..9afef963669 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go @@ -3,10 +3,11 @@ package wineventlogacquisition import ( - "runtime" + "context" "testing" "time" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" @@ -17,9 +18,8 @@ import ( ) func TestBadConfiguration(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedErr string @@ -53,20 +53,17 @@ xpath_query: test`, }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") for _, test := range tests { f := WinEventLogSource{} - err := f.Configure([]byte(test.config), subLogger) + err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) assert.Contains(t, err.Error(), test.expectedErr) } } func TestQueryBuilder(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedQuery string @@ -112,12 +109,10 @@ event_level: bla`, expectedErr: "invalid log level", }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") for _, test := range tests { f := WinEventLogSource{} - f.Configure([]byte(test.config), subLogger) + f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) q, err := f.buildXpathQuery() if test.expectedErr != "" { if err == nil { @@ -132,9 +127,8 @@ event_level: bla`, } func TestLiveAcquisition(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + ctx := context.Background() tests := []struct { config string @@ -180,9 +174,7 @@ event_ids: expectedLines: nil, }, } - subLogger := log.WithFields(log.Fields{ - "type": "windowseventlog", - }) + subLogger := log.WithField("type", "windowseventlog") evthandler, err := eventlog.Open("Application") @@ -194,8 +186,8 @@ event_ids: to := &tomb.Tomb{} c := make(chan types.Event) f := WinEventLogSource{} - f.Configure([]byte(test.config), subLogger) - f.StreamingAcquisition(c, to) + f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) + f.StreamingAcquisition(ctx, c, to) time.Sleep(time.Second) lines := test.expectedLines go func() { @@ -230,3 +222,82 @@ event_ids: to.Wait() } } + +func TestOneShotAcquisition(t *testing.T) { + tests := []struct { + name string + dsn string + expectedCount int + expectedErr string + expectedConfigureErr string + }{ + { + name: "non-existing file", + dsn: `wineventlog://foo.evtx`, + expectedCount: 0, + expectedErr: "The system cannot find the file specified.", + }, + { + name: "empty DSN", + dsn: `wineventlog://`, + expectedCount: 0, + expectedConfigureErr: "empty wineventlog:// DSN", + }, + { + name: "existing file", + dsn: `wineventlog://test_files/Setup.evtx`, + expectedCount: 24, + expectedErr: "", + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2`, + expectedCount: 1, + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2&event_id=3`, + expectedCount: 24, + }, + } + + exprhelpers.Init(nil) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + lineCount := 0 + to := &tomb.Tomb{} + c := make(chan types.Event) + f := WinEventLogSource{} + err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "wineventlog"}, log.WithField("type", "windowseventlog"), "") + + if test.expectedConfigureErr != "" { + assert.Contains(t, err.Error(), test.expectedConfigureErr) + return + } + + require.NoError(t, err) + + go func() { + for { + select { + case <-c: + lineCount++ + case <-to.Dying(): + return + } + } + }() + + err = f.OneShotAcquisition(c, to) + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } else { + require.NoError(t, err) + + time.Sleep(2 * time.Second) + assert.Equal(t, test.expectedCount, lineCount) + } + }) + } +} diff --git a/pkg/acquisition/s3.go b/pkg/acquisition/s3.go new file mode 100644 index 00000000000..73343b0408d --- /dev/null +++ b/pkg/acquisition/s3.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_s3 + +package acquisition + +import ( + s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("s3", func() DataSource { return &s3acquisition.S3Source{} }) +} diff --git a/pkg/acquisition/syslog.go b/pkg/acquisition/syslog.go new file mode 100644 index 00000000000..f62cc23b916 --- /dev/null +++ b/pkg/acquisition/syslog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_syslog + +package acquisition + +import ( + syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("syslog", func() DataSource { return &syslogacquisition.SyslogSource{} }) +} diff --git a/pkg/acquisition/wineventlog.go b/pkg/acquisition/wineventlog.go new file mode 100644 index 00000000000..0c4889a3f5c --- /dev/null +++ b/pkg/acquisition/wineventlog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_wineventlog + +package acquisition + +import ( + wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("wineventlog", func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }) +} diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go index 7586e7cb4af..16ebc6d0ac2 100644 --- a/pkg/alertcontext/alertcontext.go +++ b/pkg/alertcontext/alertcontext.go @@ -6,8 +6,8 @@ import ( "slices" "strconv" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" @@ -16,12 +16,10 @@ import ( ) const ( - maxContextValueLen = 4000 + MaxContextValueLen = 4000 ) -var ( - alertContext = Context{} -) +var alertContext = Context{} type Context struct { ContextToSend map[string][]string @@ -34,25 +32,27 @@ func ValidateContextExpr(key string, expressions []string) error { for _, expression := range expressions { _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("compilation of '%s' failed: %v", expression, err) + return fmt.Errorf("compilation of '%s' failed: %w", expression, err) } } + return nil } func NewAlertContext(contextToSend map[string][]string, valueLength int) error { - var clog = log.New() + clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - return fmt.Errorf("couldn't create logger for alert context: %s", err) + return fmt.Errorf("couldn't create logger for alert context: %w", err) } if valueLength == 0 { - clog.Debugf("No console context value length provided, using default: %d", maxContextValueLen) - valueLength = maxContextValueLen + clog.Debugf("No console context value length provided, using default: %d", MaxContextValueLen) + valueLength = MaxContextValueLen } - if valueLength > maxContextValueLen { - clog.Debugf("Provided console context value length (%d) is higher than the maximum, using default: %d", valueLength, maxContextValueLen) - valueLength = maxContextValueLen + + if valueLength > MaxContextValueLen { + clog.Debugf("Provided console context value length (%d) is higher than the maximum, using default: %d", valueLength, MaxContextValueLen) + valueLength = MaxContextValueLen } alertContext = Context{ @@ -74,8 +74,9 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { for _, value := range values { valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("compilation of '%s' context value failed: %v", value, err) + return fmt.Errorf("compilation of '%s' context value failed: %w", value, err) } + alertContext.ContextToSendCompiled[key] = append(alertContext.ContextToSendCompiled[key], valueCompiled) alertContext.ContextToSend[key] = append(alertContext.ContextToSend[key], value) } @@ -84,17 +85,14 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { return nil } -func truncate(values []string, contextValueLen int) (string, error) { - var ret string +func TruncateContext(values []string, contextValueLen int) (string, error) { valueByte, err := json.Marshal(values) if err != nil { - return "", fmt.Errorf("unable to dump metas: %s", err) + return "", fmt.Errorf("unable to dump metas: %w", err) } - ret = string(valueByte) - for { - if len(ret) <= contextValueLen { - break - } + + ret := string(valueByte) + for len(ret) > contextValueLen { // if there is only 1 value left and that the size is too big, truncate it if len(values) == 1 { valueToTruncate := values[0] @@ -106,12 +104,15 @@ func truncate(values []string, contextValueLen int) (string, error) { // if there is multiple value inside, just remove the last one values = values[:len(values)-1] } + valueByte, err = json.Marshal(values) if err != nil { - return "", fmt.Errorf("unable to dump metas: %s", err) + return "", fmt.Errorf("unable to dump metas: %w", err) } + ret = string(valueByte) } + return ret, nil } @@ -120,41 +121,49 @@ func EventToContext(events []types.Event) (models.Meta, []error) { metas := make([]*models.MetaItems0, 0) tmpContext := make(map[string][]string) + for _, evt := range events { for key, values := range alertContext.ContextToSendCompiled { if _, ok := tmpContext[key]; !ok { tmpContext[key] = make([]string, 0) } + for _, value := range values { var val string + output, err := expr.Run(value, map[string]interface{}{"evt": evt}) if err != nil { - errors = append(errors, fmt.Errorf("failed to get value for %s : %v", key, err)) + errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err)) continue } + switch out := output.(type) { case string: val = out case int: val = strconv.Itoa(out) default: - errors = append(errors, fmt.Errorf("unexpected return type for %s : %T", key, output)) + errors = append(errors, fmt.Errorf("unexpected return type for %s: %T", key, output)) continue } + if val != "" && !slices.Contains(tmpContext[key], val) { tmpContext[key] = append(tmpContext[key], val) } } } } + for key, values := range tmpContext { if len(values) == 0 { continue } - valueStr, err := truncate(values, alertContext.ContextValueLen) + + valueStr, err := TruncateContext(values, alertContext.ContextValueLen) if err != nil { - log.Warningf(err.Error()) + log.Warning(err.Error()) } + meta := models.MetaItems0{ Key: key, Value: valueStr, @@ -163,5 +172,6 @@ func EventToContext(events []types.Event) (models.Meta, []error) { } ret := models.Meta(metas) + return ret, errors } diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index 8b598eab86c..c111d1bbcfb 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -4,10 +4,11 @@ import ( "fmt" "testing" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestNewAlertContext(t *testing.T) { diff --git a/pkg/alertcontext/config.go b/pkg/alertcontext/config.go index 74ca1523a7d..6ef877619e4 100644 --- a/pkg/alertcontext/config.go +++ b/pkg/alertcontext/config.go @@ -98,20 +98,14 @@ func addContextFromFile(toSend map[string][]string, filePath string) error { return nil } - // LoadConsoleContext loads the context from the hub (if provided) and the file console_context_path. func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { c.Crowdsec.ContextToSend = make(map[string][]string, 0) if hub != nil { - items, err := hub.GetInstalledItems(cwhub.CONTEXTS) - if err != nil { - return err - } - - for _, item := range items { + for _, item := range hub.GetInstalledByType(cwhub.CONTEXTS, true) { // context in item files goes under the key 'context' - if err = addContextFromItem(c.Crowdsec.ContextToSend, item); err != nil { + if err := addContextFromItem(c.Crowdsec.ContextToSend, item); err != nil { return err } } @@ -139,7 +133,7 @@ func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { feedback, err := json.Marshal(c.Crowdsec.ContextToSend) if err != nil { - return fmt.Errorf("marshaling console context: %s", err) + return fmt.Errorf("serializing console context: %s", err) } log.Debugf("console context to send: %s", feedback) diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index ad75dd39342..a3da84d306e 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -10,8 +10,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -// type ApiAlerts service - type AlertsService service type AlertsListOpts struct { diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index 31a947556bb..0d1ff41685f 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -13,7 +13,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -35,7 +34,6 @@ func TestAlertsListAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -180,16 +178,16 @@ func TestAlertsListAsMachine(t *testing.T) { }, } - //log.Debugf("data : -> %s", spew.Sdump(alerts)) - //log.Debugf("resp : -> %s", spew.Sdump(resp)) - //log.Debugf("expected : -> %s", spew.Sdump(expected)) - //first one returns data + // log.Debugf("data : -> %s", spew.Sdump(alerts)) + // log.Debugf("resp : -> %s", spew.Sdump(resp)) + // log.Debugf("expected : -> %s", spew.Sdump(expected)) + // first one returns data alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, expected, *alerts) - //this one doesn't + // this one doesn't filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err = client.Alerts.List(context.Background(), filter) @@ -214,7 +212,6 @@ func TestAlertsGetAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -360,7 +357,7 @@ func TestAlertsGetAsMachine(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *alerts) - //fail + // fail _, _, err = client.Alerts.GetByID(context.Background(), 2) cstest.RequireErrorMessage(t, err, "API error: object not found") } @@ -388,7 +385,6 @@ func TestAlertsCreateAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -430,7 +426,6 @@ func TestAlertsDeleteAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index 71b0e273105..193486ff065 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -2,6 +2,7 @@ package apiclient import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -26,18 +27,21 @@ type JWTTransport struct { URL *url.URL VersionPrefix string UserAgent string + RetryConfig *RetryConfig // Transport is the underlying HTTP transport to use when making requests. // It will default to http.DefaultTransport if nil. Transport http.RoundTripper - UpdateScenario func() ([]string, error) + UpdateScenario func(context.Context) ([]string, error) refreshTokenMutex sync.Mutex } func (t *JWTTransport) refreshJwtToken() error { var err error + ctx := context.TODO() + if t.UpdateScenario != nil { - t.Scenarios, err = t.UpdateScenario() + t.Scenarios, err = t.UpdateScenario(ctx) if err != nil { return fmt.Errorf("can't update scenario list: %w", err) } @@ -70,9 +74,14 @@ func (t *JWTTransport) refreshJwtToken() error { req.Header.Add("Content-Type", "application/json") + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + client := &http.Client{ Transport: &retryRoundTripper{ - next: http.DefaultTransport, + next: transport, maxAttempts: 5, withBackOff: true, retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, @@ -130,52 +139,97 @@ func (t *JWTTransport) refreshJwtToken() error { return nil } -// RoundTrip implements the RoundTripper interface. -func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI - // we use a mutex to avoid this - // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) +func (t *JWTTransport) needsTokenRefresh() bool { + return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) +} + +// prepareRequest returns a copy of the request with the necessary authentication headers. +func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless + // and will cause overload on CAPI. We use a mutex to avoid this. t.refreshTokenMutex.Lock() - if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { - if err := t.refreshJwtToken(); err != nil { - t.refreshTokenMutex.Unlock() + defer t.refreshTokenMutex.Unlock() + // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, + // and it leads to do 2 requests instead of one (refresh + actual login request). + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { + if err := t.refreshJwtToken(); err != nil { return nil, err } } - t.refreshTokenMutex.Unlock() if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + req.Header.Add("Authorization", "Bearer "+t.Token) - if log.GetLevel() >= log.TraceLevel { - //requestToDump := cloneRequest(req) - dump, _ := httputil.DumpRequest(req, true) - log.Tracef("req-jwt: %s", string(dump)) - } + return req, nil +} - // Make the HTTP request. - resp, err := t.transport().RoundTrip(req) - if log.GetLevel() >= log.TraceLevel { - dump, _ := httputil.DumpResponse(resp, true) - log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) - } +// RoundTrip implements the RoundTripper interface. +func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if err != nil { - // we had an error (network error for example, or 401 because token is refused), reset the token? - t.Token = "" + var resp *http.Response + attemptsCount := make(map[int]int) - return resp, fmt.Errorf("performing jwt auth: %w", err) - } + for { + if log.GetLevel() >= log.TraceLevel { + // requestToDump := cloneRequest(req) + dump, _ := httputil.DumpRequest(req, true) + log.Tracef("req-jwt: %s", string(dump)) + } + // Make the HTTP request. + clonedReq := cloneRequest(req) - if resp != nil { - log.Debugf("resp-jwt: %d", resp.StatusCode) - } + clonedReq, err := t.prepareRequest(clonedReq) + if err != nil { + return nil, err + } + + resp, err = t.transport().RoundTrip(clonedReq) + if log.GetLevel() >= log.TraceLevel { + dump, _ := httputil.DumpResponse(resp, true) + log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) + } + + if err != nil { + // we had an error (network error for example), reset the token? + t.ResetToken() + return resp, fmt.Errorf("performing jwt auth: %w", err) + } + + if resp != nil { + log.Debugf("resp-jwt: %d", resp.StatusCode) + } + + config, shouldRetry := t.RetryConfig.StatusCodeConfig[resp.StatusCode] + if !shouldRetry { + break + } + + if attemptsCount[resp.StatusCode] >= config.MaxAttempts { + log.Infof("max attempts reached for status code %d", resp.StatusCode) + break + } + + if config.InvalidateToken { + log.Debugf("invalidating token for status code %d", resp.StatusCode) + t.ResetToken() + } + log.Debugf("retrying request to %s", req.URL.String()) + attemptsCount[resp.StatusCode]++ + log.Infof("attempt %d out of %d", attemptsCount[resp.StatusCode], config.MaxAttempts) + + if config.Backoff { + backoff := 2*attemptsCount[resp.StatusCode] + 5 + log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, attemptsCount[resp.StatusCode], config.MaxAttempts) + time.Sleep(time.Duration(backoff) * time.Second) + } + } return resp, nil + } func (t *JWTTransport) Client() *http.Client { @@ -189,29 +243,11 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } -// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded. +// transport() returns a round tripper that retries once when the status is unauthorized, +// and 5 times when the infrastructure is overloaded. func (t *JWTTransport) transport() http.RoundTripper { - transport := t.Transport - if transport == nil { - transport = http.DefaultTransport - } - - return &retryRoundTripper{ - next: &retryRoundTripper{ - next: transport, - maxAttempts: 5, - withBackOff: true, - retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout}, - }, - maxAttempts: 2, - withBackOff: false, - retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden}, - onBeforeRequest: func(attempt int) { - // reset the token only in the second attempt as this is when we know we had a 401 or 403 - // the second attempt is supposed to refresh the token - if attempt > 0 { - t.ResetToken() - } - }, + if t.Transport != nil { + return t.Transport } + return http.DefaultTransport } diff --git a/pkg/apiclient/auth_retry.go b/pkg/apiclient/auth_retry.go index 8ec8823f6e7..a17725439bc 100644 --- a/pkg/apiclient/auth_retry.go +++ b/pkg/apiclient/auth_retry.go @@ -41,7 +41,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) maxAttempts = 1 } - for i := 0; i < maxAttempts; i++ { + for i := range maxAttempts { if i > 0 { if r.withBackOff { //nolint:gosec diff --git a/pkg/apiclient/auth_service.go b/pkg/apiclient/auth_service.go index e4350385237..e7a423cfd95 100644 --- a/pkg/apiclient/auth_service.go +++ b/pkg/apiclient/auth_service.go @@ -8,8 +8,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -// type ApiAlerts service - type AuthService service // Don't add it to the models, as they are used with LAPI, but the enroll endpoint is specific to CAPI diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index f5de827a121..d22c9394014 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -14,8 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/version" - "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -37,11 +35,13 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() var payload BasicMockPayload + err := json.Unmarshal([]byte(newStr), &payload) if err != nil || payload.MachineID == "" || payload.Password == "" { log.Printf("Bad payload") @@ -49,8 +49,8 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { } var responseBody string - responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] + responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] if !hasFoundErrorMock { responseCode = http.StatusOK responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}` @@ -77,7 +77,7 @@ func TestWatcherRegister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) @@ -88,12 +88,13 @@ func TestWatcherRegister(t *testing.T) { clientconfig := Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", } - client, err := RegisterClient(&clientconfig, &http.Client{}) + ctx := context.Background() + + client, err := RegisterClient(ctx, &clientconfig, &http.Client{}) require.NoError(t, err) log.Printf("->%T", client) @@ -103,7 +104,7 @@ func TestWatcherRegister(t *testing.T) { for _, errorCodeToTest := range errorCodesToTest { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) - client, err = RegisterClient(&clientconfig, &http.Client{}) + client, err = RegisterClient(ctx, &clientconfig, &http.Client{}) require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest) require.Error(t, err, "error expected for the response code %d", errorCodeToTest) } @@ -114,7 +115,7 @@ func TestWatcherAuth(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers/login") log.Printf("URL is %s", urlx) @@ -122,11 +123,10 @@ func TestWatcherAuth(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok auth + // ok auth clientConfig := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, @@ -162,7 +162,7 @@ func TestWatcherAuth(t *testing.T) { bodyBytes, err := io.ReadAll(resp.Response.Body) require.NoError(t, err) - log.Printf(string(bodyBytes)) + log.Print(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) } @@ -175,7 +175,7 @@ func TestWatcherUnregister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") @@ -185,6 +185,7 @@ func TestWatcherUnregister(t *testing.T) { mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) @@ -207,7 +208,6 @@ func TestWatcherUnregister(t *testing.T) { mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, @@ -230,6 +230,7 @@ func TestWatcherEnroll(t *testing.T) { mux.HandleFunc("/watchers/enroll", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() @@ -261,7 +262,6 @@ func TestWatcherEnroll(t *testing.T) { mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index b183a8c7909..47d97a28344 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,14 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" - "io" + "net" "net/http" "net/url" + "strings" "github.com/golang-jwt/jwt/v4" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -39,6 +40,7 @@ type ApiClient struct { Metrics *MetricsService Signal *SignalService HeartBeat *HeartBeatService + UsageMetrics *UsageMetricsService } func (a *ApiClient) GetClient() *http.Client { @@ -65,16 +67,39 @@ type service struct { } func NewClient(config *Config) (*ApiClient, error) { + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() + } + t := &JWTTransport{ MachineID: &config.MachineID, Password: &config.Password, Scenarios: config.Scenarios, - URL: config.URL, - UserAgent: config.UserAgent, + UserAgent: userAgent, VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, + RetryConfig: NewRetryConfig( + WithStatusCodeConfig(http.StatusUnauthorized, 2, false, true), + WithStatusCodeConfig(http.StatusForbidden, 2, false, true), + WithStatusCodeConfig(http.StatusTooManyRequests, 5, true, false), + WithStatusCodeConfig(http.StatusServiceUnavailable, 5, true, false), + WithStatusCodeConfig(http.StatusGatewayTimeout, 5, true, false), + ), + } + + transport, baseURL := createTransport(config.URL) + if transport != nil { + t.Transport = transport + } else { + // can be httpmock.MockTransport + if ht, ok := http.DefaultTransport.(*http.Transport); ok { + t.Transport = ht.Clone() + } } + t.URL = baseURL + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool @@ -82,11 +107,11 @@ func NewClient(config *Config) (*ApiClient, error) { tlsconfig.Certificates = []tls.Certificate{*Cert} } - if ht, ok := http.DefaultTransport.(*http.Transport); ok { - ht.TLSClientConfig = &tlsconfig + if t.Transport != nil { + t.Transport.(*http.Transport).TLSClientConfig = &tlsconfig } - c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} + c := &ApiClient{client: t.Client(), BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) @@ -95,28 +120,40 @@ func NewClient(config *Config) (*ApiClient, error) { c.Signal = (*SignalService)(&c.common) c.DecisionDelete = (*DecisionDeleteService)(&c.common) c.HeartBeat = (*HeartBeatService)(&c.common) + c.UsageMetrics = (*UsageMetricsService)(&c.common) return c, nil } func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) { + transport, baseURL := createTransport(URL) + if client == nil { client = &http.Client{} - if ht, ok := http.DefaultTransport.(*http.Transport); ok { - tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} - tlsconfig.RootCAs = CaCertPool + if transport != nil { + client.Transport = transport + } else { + if ht, ok := http.DefaultTransport.(*http.Transport); ok { + ht = ht.Clone() + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + tlsconfig.RootCAs = CaCertPool - if Cert != nil { - tlsconfig.Certificates = []tls.Certificate{*Cert} - } + if Cert != nil { + tlsconfig.Certificates = []tls.Certificate{*Cert} + } - ht.TLSClientConfig = &tlsconfig - client.Transport = ht + ht.TLSClientConfig = &tlsconfig + client.Transport = ht + } } } - c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix} + if userAgent == "" { + userAgent = useragent.Default() + } + + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: prefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) @@ -125,31 +162,46 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt c.Signal = (*SignalService)(&c.common) c.DecisionDelete = (*DecisionDeleteService)(&c.common) c.HeartBeat = (*HeartBeatService)(&c.common) + c.UsageMetrics = (*UsageMetricsService)(&c.common) return c, nil } -func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { +func RegisterClient(ctx context.Context, config *Config, client *http.Client) (*ApiClient, error) { + transport, baseURL := createTransport(config.URL) + if client == nil { client = &http.Client{} + if transport != nil { + client.Transport = transport + } else { + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + if Cert != nil { + tlsconfig.RootCAs = CaCertPool + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + + client.Transport = http.DefaultTransport.(*http.Transport).Clone() + client.Transport.(*http.Transport).TLSClientConfig = &tlsconfig + } + } else if client.Transport == nil && transport != nil { + client.Transport = transport } - tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} - if Cert != nil { - tlsconfig.RootCAs = CaCertPool - tlsconfig.Certificates = []tls.Certificate{*Cert} + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() } - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig - c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) c.Auth = (*AuthService)(&c.common) - resp, err := c.Auth.RegisterWatcher(context.Background(), models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}) - /*if we have http status, return it*/ + resp, err := c.Auth.RegisterWatcher(ctx, models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken}) if err != nil { + /*if we have http status, return it*/ if resp != nil && resp.Response != nil { return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err) } @@ -160,60 +212,46 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { return c, nil } -type Response struct { - Response *http.Response - //add our pagination stuff - //NextPage int - //... -} - -type ErrorResponse struct { - models.ErrorResponse -} +func createTransport(url *url.URL) (*http.Transport, *url.URL) { + urlString := url.String() -func (e *ErrorResponse) Error() string { - err := fmt.Sprintf("API error: %s", *e.Message) - if len(e.Errors) > 0 { - err += fmt.Sprintf(" (%s)", e.Errors) + // TCP transport + if !strings.HasPrefix(urlString, "/") { + return nil, url } - return err -} + // Unix transport + url.Path = "/" + url.Host = "unix" + url.Scheme = "http" -func newResponse(r *http.Response) *Response { - return &Response{Response: r} + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(urlString, "/")) + }, + }, url } -func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { - return nil - } - - errorResponse := &ErrorResponse{} - - data, err := io.ReadAll(r.Body) - if err == nil && len(data)>0 { - err := json.Unmarshal(data, errorResponse) - if err != nil { - return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) - } - } else { - errorResponse.Message = new(string) - *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) - } +type Response struct { + Response *http.Response + // add our pagination stuff + // NextPage int + // ... +} - return errorResponse +func newResponse(r *http.Response) *Response { + return &Response{Response: r} } type ListOpts struct { - //Page int - //PerPage int + // Page int + // PerPage int } type DeleteOpts struct { - //?? + // ?? } type AddOpts struct { - //?? + // ?? } diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index a7582eaf437..45cd8410a8e 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -2,7 +2,6 @@ package apiclient import ( "context" - "fmt" "net/http" "net/url" "testing" @@ -11,21 +10,19 @@ import ( "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" - "github.com/crowdsecurity/go-cs-lib/version" ) func TestNewRequestInvalid(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in uri + // missing slash in uri apiURL, err := url.Parse(urlx) require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -57,7 +54,6 @@ func TestNewRequestTimeout(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index dc6eae16926..d1f58f33ad2 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -3,10 +3,13 @@ package apiclient import ( "context" "fmt" + "net" "net/http" "net/http/httptest" "net/url" + "path" "runtime" + "strings" "testing" log "github.com/sirupsen/logrus" @@ -14,7 +17,6 @@ import ( "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" - "github.com/crowdsecurity/go-cs-lib/version" ) /*this is a ripoff of google/go-github approach : @@ -34,12 +36,50 @@ func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) { apiHandler := http.NewServeMux() apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) - // server is a test HTTP server used to provide mock API responses. server := httptest.NewServer(apiHandler) return mux, server.URL, server.Close } +// toUNCPath converts a Windows file path to a UNC path. +// This is necessary because the Go http package does not support Windows file paths. +func toUNCPath(path string) (string, error) { + colonIdx := strings.Index(path, ":") + if colonIdx == -1 { + return "", fmt.Errorf("invalid path format, missing drive letter: %s", path) + } + + // URL parsing does not like backslashes + remaining := strings.ReplaceAll(path[colonIdx+1:], "\\", "/") + uncPath := "//localhost/" + path[:colonIdx] + "$" + remaining + + return uncPath, nil +} + +func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { + var err error + if runtime.GOOS == "windows" { + socket, err = toUNCPath(socket) + if err != nil { + log.Fatalf("converting to UNC path: %s", err) + } + } + + mux = http.NewServeMux() + baseURLPath := "/" + urlPrefix + + apiHandler := http.NewServeMux() + apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) + + server := httptest.NewUnstartedServer(apiHandler) + l, _ := net.Listen("unix", socket) + _ = server.Listener.Close() + server.Listener = l + server.Start() + + return mux, socket, server.Close +} + func testMethod(t *testing.T, r *http.Request, want string) { t.Helper() assert.Equal(t, want, r.Method) @@ -55,7 +95,6 @@ func TestNewClientOk(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -77,6 +116,48 @@ func TestNewClientOk(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } +func TestNewClientOk_UnixSocket(t *testing.T) { + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + apiURL, err := url.Parse(urlx) + if err != nil { + t.Fatalf("parsing api url: %s", apiURL) + } + + client, err := NewClient(&Config{ + MachineID: "test_login", + Password: "test_password", + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + t.Fatalf("new api client: %s", err) + } + /*mock login*/ + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) + }) + + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusOK) + }) + + _, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) + if err != nil { + t.Fatalf("test Unable to list alerts : %+v", err) + } + + if resp.Response.StatusCode != http.StatusOK { + t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) + } +} + func TestNewClientKo(t *testing.T) { mux, urlx, teardown := setup() defer teardown() @@ -87,7 +168,6 @@ func TestNewClientKo(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -131,22 +211,50 @@ func TestNewDefaultClient(t *testing.T) { log.Printf("err-> %s", err) } +func TestNewDefaultClient_UnixSocket(t *testing.T) { + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + apiURL, err := url.Parse(urlx) + if err != nil { + t.Fatalf("parsing api url: %s", apiURL) + } + + client, err := NewDefaultClient(apiURL, "/v1", "", nil) + if err != nil { + t.Fatalf("new api client: %s", err) + } + + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"code": 401, "message" : "brr"}`)) + }) + + _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) + assert.Contains(t, err.Error(), `performing request: API error: brr`) + log.Printf("err-> %s", err) +} + func TestNewClientRegisterKO(t *testing.T) { apiURL, err := url.Parse("http://127.0.0.1:4242/") require.NoError(t, err) - _, err = RegisterClient(&Config{ + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - if runtime.GOOS != "windows" { - cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused") - } else { + if runtime.GOOS == "windows" { cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.") + } else { + cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused") } } @@ -166,10 +274,11 @@ func TestNewClientRegisterOK(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - client, err := RegisterClient(&Config{ + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) @@ -178,6 +287,42 @@ func TestNewClientRegisterOK(t *testing.T) { log.Printf("->%T", client) } +func TestNewClientRegisterOK_UnixSocket(t *testing.T) { + log.SetLevel(log.TraceLevel) + + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + + mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1") + defer teardown() + + /*mock login*/ + mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "POST") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) + }) + + apiURL, err := url.Parse(urlx) + if err != nil { + t.Fatalf("parsing api url: %s", apiURL) + } + + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ + MachineID: "test_login", + Password: "test_password", + URL: apiURL, + VersionPrefix: "v1", + }, &http.Client{}) + if err != nil { + t.Fatalf("while registering client : %s", err) + } + + log.Printf("->%T", client) +} + func TestNewClientBadAnswer(t *testing.T) { log.SetLevel(log.TraceLevel) @@ -194,12 +339,13 @@ func TestNewClientBadAnswer(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - _, err = RegisterClient(&Config{ + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value") + cstest.RequireErrorContains(t, err, "API error: http code 401, response: bad") } diff --git a/pkg/apiclient/config.go b/pkg/apiclient/config.go index 4dfeb3e863f..29a8acf185e 100644 --- a/pkg/apiclient/config.go +++ b/pkg/apiclient/config.go @@ -1,18 +1,20 @@ package apiclient import ( + "context" "net/url" "github.com/go-openapi/strfmt" ) type Config struct { - MachineID string - Password strfmt.Password - Scenarios []string - URL *url.URL - PapiURL *url.URL - VersionPrefix string - UserAgent string - UpdateScenario func() ([]string, error) + MachineID string + Password strfmt.Password + Scenarios []string + URL *url.URL + PapiURL *url.URL + VersionPrefix string + UserAgent string + RegistrationToken string + UpdateScenario func(context.Context) ([]string, error) } diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index 388a870f999..98f26cad9ae 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -144,7 +144,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) for idx, decision := range decisionsGroup.Decisions { - decision := decision // fix exportloopref linter message + decision := decision //nolint:copyloopvar // fix exportloopref linter message partialDecisions[idx] = &models.Decision{ Scenario: &scenarioDeleted, Scope: decisionsGroup.Scope, diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index fb2fb7342f7..54c44f43eda 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -2,7 +2,6 @@ package apiclient import ( "context" - "fmt" "net/http" "net/url" "testing" @@ -13,7 +12,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" @@ -27,6 +25,7 @@ func TestDecisionsList(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.URL.RawQuery == "ip=1.2.3.4" { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) @@ -35,14 +34,14 @@ func TestDecisionsList(t *testing.T) { } else { w.WriteHeader(http.StatusOK) w.Write([]byte(`null`)) - //no results + // no results } }) apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -69,7 +68,7 @@ func TestDecisionsList(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *decisions) - //Empty return + // Empty return decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")} decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter) require.NoError(t, err) @@ -86,6 +85,7 @@ func TestDecisionsStream(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { w.WriteHeader(http.StatusOK) @@ -100,6 +100,7 @@ func TestDecisionsStream(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) + if r.Method == http.MethodDelete { w.WriteHeader(http.StatusOK) } @@ -108,7 +109,7 @@ func TestDecisionsStream(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -135,14 +136,14 @@ func TestDecisionsStream(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *decisions) - //and second call, we get empty lists + // and second call, we get empty lists decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false}) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, decisions.New) assert.Empty(t, decisions.Deleted) - //delete stream + // delete stream resp, err = newcli.Decisions.StopStream(context.Background()) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) @@ -157,6 +158,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { w.WriteHeader(http.StatusOK) @@ -171,7 +173,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -221,6 +223,7 @@ func TestDecisionsStreamV3(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}], @@ -232,7 +235,7 @@ func TestDecisionsStreamV3(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -306,7 +309,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -392,7 +395,7 @@ func TestDeleteDecisions(t *testing.T) { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"nbDeleted":"1"}`)) - //w.Write([]byte(`{"message":"0 deleted alerts"}`)) + // w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) log.Printf("URL is %s", urlx) @@ -403,7 +406,6 @@ func TestDeleteDecisions(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, VersionPrefix: "v1", }) @@ -459,7 +461,6 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { o := &DecisionsStreamOpts{ Startup: tt.fields.Startup, @@ -470,6 +471,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { got, err := o.addQueryParamsToURL(baseURLString) cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } @@ -504,7 +506,6 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { // client, err := NewClient(&Config{ // MachineID: "test_login", // Password: "test_password", -// UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), // URL: apiURL, // VersionPrefix: "v1", // }) diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go new file mode 100644 index 00000000000..1b0786f9882 --- /dev/null +++ b/pkg/apiclient/resperr.go @@ -0,0 +1,61 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type ErrorResponse struct { + models.ErrorResponse +} + +func (e *ErrorResponse) Error() string { + message := ptr.OrEmpty(e.Message) + errors := "" + + if e.Errors != "" { + errors = fmt.Sprintf(" (%s)", e.Errors) + } + + if message == "" && errors == "" { + errors = "(no errors)" + } + + return fmt.Sprintf("API error: %s%s", message, errors) +} + +// CheckResponse verifies the API response and builds an appropriate Go error if necessary. +func CheckResponse(r *http.Response) error { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { + return nil + } + + ret := &ErrorResponse{} + + data, err := io.ReadAll(r.Body) + if err != nil || len(data) == 0 { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, no response body", r.StatusCode)) + return ret + } + + switch r.StatusCode { + case http.StatusUnprocessableEntity: + ret.Message = ptr.Of(fmt.Sprintf("http code %d, invalid request: %s", r.StatusCode, string(data))) + default: + // try to unmarshal and if there are no 'message' or 'errors' fields, display the body as is, + // the API is following a different convention + err := json.Unmarshal(data, ret) + if err != nil || (ret.Message == nil && ret.Errors == "") { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, response: %s", r.StatusCode, string(data))) + return ret + } + } + + return ret +} diff --git a/pkg/apiclient/retry_config.go b/pkg/apiclient/retry_config.go new file mode 100644 index 00000000000..8a0d1096f84 --- /dev/null +++ b/pkg/apiclient/retry_config.go @@ -0,0 +1,33 @@ +package apiclient + +type StatusCodeConfig struct { + MaxAttempts int + Backoff bool + InvalidateToken bool +} + +type RetryConfig struct { + StatusCodeConfig map[int]StatusCodeConfig +} + +type RetryConfigOption func(*RetryConfig) + +func NewRetryConfig(options ...RetryConfigOption) *RetryConfig { + rc := &RetryConfig{ + StatusCodeConfig: make(map[int]StatusCodeConfig), + } + for _, opt := range options { + opt(rc) + } + return rc +} + +func WithStatusCodeConfig(statusCode int, maxAttempts int, backOff bool, invalidateToken bool) RetryConfigOption { + return func(rc *RetryConfig) { + rc.StatusCodeConfig[statusCode] = StatusCodeConfig{ + MaxAttempts: maxAttempts, + Backoff: backOff, + InvalidateToken: invalidateToken, + } + } +} diff --git a/pkg/apiclient/usagemetrics.go b/pkg/apiclient/usagemetrics.go new file mode 100644 index 00000000000..1d822bb5c1e --- /dev/null +++ b/pkg/apiclient/usagemetrics.go @@ -0,0 +1,29 @@ +package apiclient + +import ( + "context" + "fmt" + "net/http" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type UsageMetricsService service + +func (s *UsageMetricsService) Add(ctx context.Context, metrics *models.AllMetrics) (interface{}, *Response, error) { + u := fmt.Sprintf("%s/usage-metrics", s.client.URLPrefix) + + req, err := s.client.NewRequest(http.MethodPost, u, &metrics) + if err != nil { + return nil, nil, err + } + + var response interface{} + + resp, err := s.client.Do(ctx, req, &response) + if err != nil { + return nil, resp, err + } + + return &response, resp, nil +} diff --git a/pkg/apiclient/useragent/useragent.go b/pkg/apiclient/useragent/useragent.go new file mode 100644 index 00000000000..5a62ce1ac06 --- /dev/null +++ b/pkg/apiclient/useragent/useragent.go @@ -0,0 +1,9 @@ +package useragent + +import ( + "github.com/crowdsecurity/go-cs-lib/version" +) + +func Default() string { + return "crowdsec/" + version.String() + "-" + version.System +} diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 5365058176d..4cc215c344f 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "net/http" @@ -25,11 +26,11 @@ type LAPI struct { DBConfig *csconfig.DatabaseCfg } -func SetupLAPITest(t *testing.T) LAPI { +func SetupLAPITest(t *testing.T, ctx context.Context) LAPI { t.Helper() - router, loginResp, config := InitMachineTest(t) + router, loginResp, config := InitMachineTest(t, ctx) - APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) return LAPI{ router: router, @@ -39,14 +40,14 @@ func SetupLAPITest(t *testing.T) LAPI { } } -func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder { +func (l *LAPI) InsertAlertFromFile(t *testing.T, ctx context.Context, path string) *httptest.ResponseRecorder { alertReader := GetAlertReaderFromFile(t, path) - return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password") + return l.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() - req, err := http.NewRequest(verb, url, body) + req, err := http.NewRequestWithContext(ctx, verb, url, body) require.NoError(t, err) switch authType { @@ -63,19 +64,19 @@ func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strin return w } -func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { - router, config := NewAPITest(t) - loginResp := LoginToTestAPI(t, router, config) +func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { + router, config := NewAPITest(t, ctx) + loginResp := LoginToTestAPI(t, ctx, router, config) return router, loginResp, config } -func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { - body := CreateTestMachine(t, router) - ValidateMachine(t, "test", config.API.Server.DbConfig) +func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, ctx, router, "") + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -92,50 +93,55 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon } func TestSimulatedAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk+simul.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json") - //exclude decision in simulation mode + // exclude decision in simulation mode - w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) - //include decision in simulation mode + // include decision in simulation mode - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) } func TestCreateAlert(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Alert with invalid format - w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") - w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password") + w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) + assert.Equal(t, + `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, + w.Body.String()) // Create Valid Alert - w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + w = lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assert.Equal(t, 201, w.Code) assert.Equal(t, `["1"]`, w.Body.String()) } func TestCreateAlertChannels(t *testing.T) { - apiServer, config := NewAPIServer(t) + ctx := context.Background() + apiServer, config := NewAPIServer(t, ctx) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() - loginResp := LoginToTestAPI(t, apiServer.router, config) + loginResp := LoginToTestAPI(t, ctx, apiServer.router, config) lapi := LAPI{router: apiServer.router, loginResp: loginResp} var ( @@ -151,221 +157,225 @@ func TestCreateAlertChannels(t *testing.T) { wg.Done() }() - go lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") wg.Wait() assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() } func TestAlertListFilters(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json") - //bad filter + // bad filter - w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) - //get without filters + // get without filters - w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) - //check alert and decision + // check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (ok) + // test decision_type filter (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (bad value) + // test decision_type filter (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scope (ok) + // test scope (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scope (bad value) + // test scope (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scenario (ok) + // test scenario (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scenario (bad value) + // test scenario (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (ok) + // test ip (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test ip (bad value) + // test ip (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (invalid value) + // test ip (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test range (ok) + // test range (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test range + // test range - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test range (invalid value) + // test range (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test since (ok) + // test since (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test since (ok but yields no results) + // test since (ok but yields no results) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test since (invalid value) + // test since (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test until (ok) + // test until (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test until (ok but no return) + // test until (ok but no return) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test until (invalid value) + // test until (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test has active decision (invalid value) + // test has active decision (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } func TestAlertBulkInsert(t *testing.T) { - lapi := SetupLAPITest(t) - //insert a bulk of 20 alerts to trigger bulk insert - lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + // insert a bulk of 20 alerts to trigger bulk insert + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_bulk.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") - w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } func TestListAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert - w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } func TestCreateAlertErrors(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") - //test invalid bearer + // test invalid bearer w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - //test invalid bearer + // test invalid bearer w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) @@ -373,12 +383,13 @@ func TestCreateAlertErrors(t *testing.T) { } func TestDeleteAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -387,7 +398,7 @@ func TestDeleteAlert(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -396,12 +407,13 @@ func TestDeleteAlert(t *testing.T) { } func TestDeleteAlertByID(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -410,7 +422,7 @@ func TestDeleteAlertByID(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -419,12 +431,13 @@ func TestDeleteAlertByID(t *testing.T) { } func TestDeleteAlertTrustedIPS(t *testing.T) { + ctx := context.Background() cfg := LoadTestConfig(t) // IPv6 mocking doesn't seem to work. // cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"} cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" - server, err := NewServer(cfg.API.Server) + server, err := NewServer(ctx, cfg.API.Server) require.NoError(t, err) err = server.InitController() @@ -433,7 +446,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { router, err := server.Router() require.NoError(t, err) - loginResp := LoginToTestAPI(t, router, cfg) + loginResp := LoginToTestAPI(t, ctx, router, cfg) lapi := LAPI{ router: router, loginResp: loginResp, @@ -441,7 +454,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeleteFailedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -453,7 +466,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeletedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -462,17 +475,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeleteFailedFromIP("4.3.2.1") assertAlertDeletedFromIP("1.2.3.4") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.0") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.1") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.255") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 883ff21298d..e6ed68a6e0d 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -10,13 +11,14 @@ import ( ) func TestAPIKey(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) - APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) // Login with empty token w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +27,7 @@ func TestAPIKey(t *testing.T) { // Login with invalid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") router.ServeHTTP(w, req) @@ -35,7 +37,7 @@ func TestAPIKey(t *testing.T) { // Login with valid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index d0b205c254d..a2fb0e85749 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -20,7 +20,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" - "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -35,26 +34,30 @@ import ( const ( // delta values must be smaller than the interval - pullIntervalDefault = time.Hour * 2 - pullIntervalDelta = 5 * time.Minute - pushIntervalDefault = time.Second * 10 - pushIntervalDelta = time.Second * 7 - metricsIntervalDefault = time.Minute * 30 - metricsIntervalDelta = time.Minute * 15 + pullIntervalDefault = time.Hour * 2 + pullIntervalDelta = 5 * time.Minute + pushIntervalDefault = time.Second * 10 + pushIntervalDelta = time.Second * 7 + metricsIntervalDefault = time.Minute * 30 + metricsIntervalDelta = time.Minute * 15 + usageMetricsInterval = time.Minute * 30 + usageMetricsIntervalDelta = time.Minute * 15 ) type apic struct { // when changing the intervals in tests, always set *First too // or they can be negative - pullInterval time.Duration - pullIntervalFirst time.Duration - pushInterval time.Duration - pushIntervalFirst time.Duration - metricsInterval time.Duration - metricsIntervalFirst time.Duration - dbClient *database.Client - apiClient *apiclient.ApiClient - AlertsAddChan chan []*models.Alert + pullInterval time.Duration + pullIntervalFirst time.Duration + pushInterval time.Duration + pushIntervalFirst time.Duration + metricsInterval time.Duration + metricsIntervalFirst time.Duration + usageMetricsInterval time.Duration + usageMetricsIntervalFirst time.Duration + dbClient *database.Client + apiClient *apiclient.ApiClient + AlertsAddChan chan []*models.Alert mu sync.Mutex pushTomb tomb.Tomb @@ -79,14 +82,14 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { return ret } -func (a *apic) FetchScenariosListFromDB() ([]string, error) { +func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } - //merge all scenarios together + // merge all scenarios together for _, v := range machines { machineScenarios := strings.Split(v.Scenarios, ",") log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID) @@ -113,7 +116,7 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ Origin: ptr.Of(*decision.Origin), Scenario: ptr.Of(*decision.Scenario), Scope: ptr.Of(*decision.Scope), - //Simulated: *decision.Simulated, + // Simulated: *decision.Simulated, Type: ptr.Of(*decision.Type), Until: decision.Until, Value: ptr.Of(*decision.Value), @@ -171,33 +174,35 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error ret := &apic{ - AlertsAddChan: make(chan []*models.Alert), - dbClient: dbClient, - mu: sync.Mutex{}, - startup: true, - credentials: config.Credentials, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, - scenarioList: make([]string, 0), - consoleConfig: consoleConfig, - pullInterval: pullIntervalDefault, - pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), - pushInterval: pushIntervalDefault, - pushIntervalFirst: randomDuration(pushIntervalDefault, pushIntervalDelta), - metricsInterval: metricsIntervalDefault, - metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta), - isPulling: make(chan bool, 1), - whitelists: apicWhitelist, + AlertsAddChan: make(chan []*models.Alert), + dbClient: dbClient, + mu: sync.Mutex{}, + startup: true, + credentials: config.Credentials, + pullTomb: tomb.Tomb{}, + pushTomb: tomb.Tomb{}, + metricsTomb: tomb.Tomb{}, + scenarioList: make([]string, 0), + consoleConfig: consoleConfig, + pullInterval: pullIntervalDefault, + pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), + pushInterval: pushIntervalDefault, + pushIntervalFirst: randomDuration(pushIntervalDefault, pushIntervalDelta), + metricsInterval: metricsIntervalDefault, + metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta), + usageMetricsInterval: usageMetricsInterval, + usageMetricsIntervalFirst: randomDuration(usageMetricsInterval, usageMetricsIntervalDelta), + isPulling: make(chan bool, 1), + whitelists: apicWhitelist, } password := strfmt.Password(config.Credentials.Password) - apiURL, err := url.Parse(config.Credentials.URL) + apiURL, err := url.Parse(config.Credentials.URL) if err != nil { return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err) } @@ -207,7 +212,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) } - ret.scenarioList, err = ret.FetchScenariosListFromDB() + ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx) if err != nil { return nil, fmt.Errorf("while fetching scenarios from db: %w", err) } @@ -215,7 +220,6 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con ret.apiClient, err = apiclient.NewClient(&apiclient.Config{ MachineID: config.Credentials.Login, Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), URL: apiURL, PapiURL: papiURL, VersionPrefix: "v3", @@ -228,12 +232,12 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con // The watcher will be authenticated by the RoundTripper the first time it will call CAPI // Explicit authentication will provoke a useless supplementary call to CAPI - scenarios, err := ret.FetchScenariosListFromDB() + scenarios, err := ret.FetchScenariosListFromDB(ctx) if err != nil { return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, @@ -252,7 +256,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con } // keep track of all alerts in cache and push it to CAPI every PushInterval. -func (a *apic) Push() error { +func (a *apic) Push(ctx context.Context) error { defer trace.CatchPanic("lapi/pushToAPIC") var cache models.AddSignalsRequest @@ -272,7 +276,7 @@ func (a *apic) Push() error { return nil } - go a.Send(&cache) + go a.Send(ctx, &cache) return nil case <-ticker.C: @@ -285,7 +289,7 @@ func (a *apic) Push() error { a.mu.Unlock() log.Infof("Signal push: %d signals to push", len(cacheCopy)) - go a.Send(&cacheCopy) + go a.Send(ctx, &cacheCopy) } case alerts := <-a.AlertsAddChan: var signals []*models.AddSignalsRequestItem @@ -347,7 +351,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig return true } -func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { +func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) { /*we do have a problem with this : The apic.Push background routine reads from alertToPush chan. This chan is filled by Controller.CreateAlert @@ -371,12 +375,11 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { for { if pageEnd >= len(cache) { send = cache[pageStart:] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() _, _, err := a.apiClient.Signal.Add(ctx, &send) - if err != nil { log.Errorf("sending signal to central API: %s", err) return @@ -386,14 +389,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } send = cache[pageStart:pageEnd] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() _, _, err := a.apiClient.Signal.Add(ctx, &send) - if err != nil { - //we log it here as well, because the return value of func might be discarded + // we log it here as well, because the return value of func might be discarded log.Errorf("sending signal to central API: %s", err) } @@ -402,13 +404,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -422,6 +424,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) { } func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { + ctx := context.TODO() nbDeleted := 0 for _, decision := range deletedDecisions { @@ -434,9 +437,9 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet filter["scopes"] = []string{*decision.Scope} } - dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return 0, fmt.Errorf("deleting decisions error: %w", err) + return 0, fmt.Errorf("expiring decisions error: %w", err) } dbCliDel, err := strconv.Atoi(dbCliRet) @@ -451,7 +454,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet return nbDeleted, nil } -func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { +func (a *apic) HandleDeletedDecisionsV3(ctx context.Context, deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int for _, decisions := range deletedDecisions { @@ -466,9 +469,9 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi filter["scopes"] = []string{*scope} } - dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return 0, fmt.Errorf("deleting decisions error: %w", err) + return 0, fmt.Errorf("expiring decisions error: %w", err) } dbCliDel, err := strconv.Atoi(dbCliRet) @@ -506,6 +509,7 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { if sub.Scenario == nil { log.Warningf("nil scenario in %+v", sub) } + if *sub.Scenario == *decision.Scenario { found = true break @@ -539,7 +543,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { scenario = *decision.Scenario scope = types.ListOrigin default: - // XXX: this or nil? scenario = "" scope = "" @@ -568,7 +571,7 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { // This function takes in list of parent alerts and decisions and then pairs them up. func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert { for _, decision := range decisions { - //count and create separate alerts for each list + // count and create separate alerts for each list updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1) /*CAPI might send lower case scopes, unify it.*/ @@ -580,7 +583,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio } found := false - //add the individual decisions to the right list + // add the individual decisions to the right list for idx, alert := range alerts { if *decision.Origin == types.CAPIOrigin { if *alert.Source.Scope == types.CAPIOrigin { @@ -593,6 +596,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio if *alert.Source.Scope == types.ListOrigin && *alert.Scenario == *decision.Scenario { alerts[idx].Decisions = append(alerts[idx].Decisions, decision) found = true + break } } else { @@ -611,11 +615,11 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop(forcePull bool) error { +func (a *apic) PullTop(ctx context.Context, forcePull bool) error { var err error - //A mutex with TryLock would be a bit simpler - //But go does not guarantee that TryLock will be able to acquire the lock even if it is available + // A mutex with TryLock would be a bit simpler + // But go does not guarantee that TryLock will be able to acquire the lock even if it is available select { case a.isPulling <- true: defer func() { @@ -626,16 +630,33 @@ func (a *apic) PullTop(forcePull bool) error { } if !forcePull { - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil { return err } else if !lastPullIsOld { return nil } } + log.Debug("Acquiring lock for pullCAPI") + + err = a.dbClient.AcquirePullCAPILock(ctx) + if a.dbClient.IsLocked(err) { + log.Info("PullCAPI is already running, skipping") + return nil + } + + /*defer lock release*/ + defer func() { + log.Debug("Releasing lock for pullCAPI") + + if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil { + log.Errorf("while releasing lock: %v", err) + } + }() + log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -653,7 +674,7 @@ func (a *apic) PullTop(forcePull bool) error { addCounters, deleteCounters := makeAddAndDeleteCounters() // process deleted decisions - nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) + nbDeleted, err := a.HandleDeletedDecisionsV3(ctx, data.Deleted, deleteCounters) if err != nil { return err } @@ -667,20 +688,20 @@ func (a *apic) PullTop(forcePull bool) error { // create one alert for community blocklist using the first decision decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New) - //apply APIC specific whitelists + // apply APIC specific whitelists decisions = a.ApplyApicWhitelists(decisions) alert := createAlertForDecision(decisions[0]) alertsFromCapi := []*models.Alert{alert} alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters) + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters) if err != nil { return fmt.Errorf("while saving alerts: %w", err) } // update blocklists - if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } @@ -688,9 +709,9 @@ func (a *apic) PullTop(forcePull bool) error { } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert -func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { +func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) @@ -726,7 +747,7 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis if a.whitelists == nil || len(a.whitelists.Cidrs) == 0 && len(a.whitelists.Ips) == 0 { return decisions } - //deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place + // deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place outIdx := 0 for _, decision := range decisions { @@ -739,11 +760,11 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis decisions[outIdx] = decision outIdx++ } - //shrink the list, those are deleted items + // shrink the list, those are deleted items return decisions[:outIdx] } -func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { +func (a *apic) SaveAlerts(ctx context.Context, alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { for _, alert := range alertsFromCapi { setAlertScenario(alert, addCounters, deleteCounters) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) @@ -752,7 +773,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } @@ -763,13 +784,13 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string return nil } -func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) { +func (a *apic) ShouldForcePullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) (bool, error) { // we should force pull if the blocklist decisions are about to expire or there's no decision in the db alertQuery := a.dbClient.Ent.Alert.Query() alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) - alertInstance, err := alertQuery.First(context.Background()) + alertInstance, err := alertQuery.First(ctx) if err != nil { if ent.IsNotFound(err) { log.Debugf("no alert found for %s, force refresh", *blocklist.Name) @@ -781,8 +802,8 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo decisionQuery := a.dbClient.Ent.Decision.Query() decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) - firstDecision, err := decisionQuery.First(context.Background()) + firstDecision, err := decisionQuery.First(ctx) if err != nil { if ent.IsNotFound(err) { log.Debugf("no decision found for %s, force refresh", *blocklist.Name) @@ -800,7 +821,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -812,7 +833,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap } if !forcePull { - _forcePull, err := a.ShouldForcePullBlocklist(blocklist) + _forcePull, err := a.ShouldForcePullBlocklist(ctx, blocklist) if err != nil { return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) } @@ -828,13 +849,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } - decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } @@ -849,7 +870,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -858,13 +879,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap log.Infof("blocklist %s has no decisions", *blocklist.Name) return nil } - //apply APIC specific whitelists + // apply APIC specific whitelists decisions = a.ApplyApicWhitelists(decisions) alert := createAlertForDecision(decisions[0]) alertsFromCapi := []*models.Alert{alert} alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(alertsFromCapi, addCounters, nil) + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, nil) if err != nil { return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) } @@ -872,7 +893,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -888,7 +909,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } @@ -897,22 +918,27 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) { - if *alert.Source.Scope == types.CAPIOrigin { + switch *alert.Source.Scope { + case types.CAPIOrigin: *alert.Source.Scope = types.CommunityBlocklistPullSourceScope - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"])) - } else if *alert.Source.Scope == types.ListOrigin { + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.CAPIOrigin]["all"], + deleteCounters[types.CAPIOrigin]["all"])) + case types.ListOrigin: *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", + addCounters[types.ListOrigin][*alert.Scenario], + deleteCounters[types.ListOrigin][*alert.Scenario])) } } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false for { - scenario, err := a.FetchScenariosListFromDB() + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) } @@ -930,7 +956,7 @@ func (a *apic) Pull() error { time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -942,7 +968,7 @@ func (a *apic) Pull() error { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } @@ -974,11 +1000,12 @@ func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[strin } func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { - if *origin == types.CAPIOrigin { + switch *origin { + case types.CAPIOrigin: counter[*origin]["all"] += totalDecisions - } else if *origin == types.ListOrigin { + case types.ListOrigin: counter[*origin][*scenario] += totalDecisions - } else { + default: log.Warningf("Unknown origin %s", *origin) } } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 128ce5a9639..3d9e7b28a79 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -2,20 +2,191 @@ package apiserver import ( "context" + "encoding/json" + "net/http" + "slices" + "strings" "time" log "github.com/sirupsen/logrus" - "slices" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (a *apic) GetMetrics() (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() +type dbPayload struct { + Metrics []*models.DetailedMetrics `json:"metrics"` +} + +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { + allMetrics := &models.AllMetrics{} + metricsIds := make([]int, 0) + + lps, err := a.dbClient.ListMachines(ctx) + if err != nil { + return nil, nil, err + } + + bouncers, err := a.dbClient.ListBouncers(ctx) + if err != nil { + return nil, nil, err + } + + for _, bouncer := range bouncers { + dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(ctx, bouncer.Name) + if err != nil { + log.Errorf("unable to get bouncer usage metrics: %s", err) + continue + } + + rcMetrics := models.RemediationComponentsMetrics{} + + rcMetrics.Os = &models.OSversion{ + Name: ptr.Of(bouncer.Osname), + Version: ptr.Of(bouncer.Osversion), + } + rcMetrics.Type = bouncer.Type + rcMetrics.FeatureFlags = strings.Split(bouncer.Featureflags, ",") + rcMetrics.Version = ptr.Of(bouncer.Version) + rcMetrics.Name = bouncer.Name + + rcMetrics.LastPull = 0 + if bouncer.LastPull != nil { + rcMetrics.LastPull = bouncer.LastPull.UTC().Unix() + } + + rcMetrics.Metrics = make([]*models.DetailedMetrics, 0) + + // Might seem weird, but we duplicate the bouncers if we have multiple unsent metrics + for _, dbMetric := range dbMetrics { + dbPayload := &dbPayload{} + // Append no matter what, if we cannot unmarshal, there's no way we'll be able to fix it automatically + metricsIds = append(metricsIds, dbMetric.ID) + + err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload) + if err != nil { + log.Errorf("unable to parse bouncer metric (%s)", err) + continue + } + + rcMetrics.Metrics = append(rcMetrics.Metrics, dbPayload.Metrics...) + } + + allMetrics.RemediationComponents = append(allMetrics.RemediationComponents, &rcMetrics) + } + + for _, lp := range lps { + dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(ctx, lp.MachineId) + if err != nil { + log.Errorf("unable to get LP usage metrics: %s", err) + continue + } + + lpMetrics := models.LogProcessorsMetrics{} + + lpMetrics.Os = &models.OSversion{ + Name: ptr.Of(lp.Osname), + Version: ptr.Of(lp.Osversion), + } + lpMetrics.FeatureFlags = strings.Split(lp.Featureflags, ",") + lpMetrics.Version = ptr.Of(lp.Version) + lpMetrics.Name = lp.MachineId + + lpMetrics.LastPush = 0 + if lp.LastPush != nil { + lpMetrics.LastPush = lp.LastPush.UTC().Unix() + } + + lpMetrics.LastUpdate = lp.UpdatedAt.UTC().Unix() + lpMetrics.Datasources = lp.Datasources + + hubItems := models.HubItems{} + + if lp.Hubstate != nil { + // must carry over the hub state even if nothing is installed + for itemType, items := range lp.Hubstate { + hubItems[itemType] = []models.HubItem{} + for _, item := range items { + hubItems[itemType] = append(hubItems[itemType], models.HubItem{ + Name: item.Name, + Status: item.Status, + Version: item.Version, + }) + } + } + } + + lpMetrics.HubItems = hubItems + + lpMetrics.Metrics = make([]*models.DetailedMetrics, 0) + + for _, dbMetric := range dbMetrics { + dbPayload := &dbPayload{} + // Append no matter what, if we cannot unmarshal, there's no way we'll be able to fix it automatically + metricsIds = append(metricsIds, dbMetric.ID) + + err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload) + if err != nil { + log.Errorf("unable to parse log processor metric (%s)", err) + continue + } + + lpMetrics.Metrics = append(lpMetrics.Metrics, dbPayload.Metrics...) + } + + allMetrics.LogProcessors = append(allMetrics.LogProcessors, &lpMetrics) + } + + // FIXME: all of this should only be done once on startup/reload + consoleOptions := strings.Join(csconfig.GetConfig().API.Server.ConsoleConfig.EnabledOptions(), ",") + allMetrics.Lapi = &models.LapiMetrics{ + ConsoleOptions: models.ConsoleOptions{ + consoleOptions, + }, + } + + osName, osVersion := version.DetectOS() + + allMetrics.Lapi.Os = &models.OSversion{ + Name: ptr.Of(osName), + Version: ptr.Of(osVersion), + } + allMetrics.Lapi.Version = ptr.Of(version.String()) + allMetrics.Lapi.FeatureFlags = fflag.Crowdsec.GetEnabledFeatures() + + allMetrics.Lapi.Metrics = make([]*models.DetailedMetrics, 0) + + allMetrics.Lapi.Metrics = append(allMetrics.Lapi.Metrics, &models.DetailedMetrics{ + Meta: &models.MetricsMeta{ + UtcNowTimestamp: ptr.Of(time.Now().UTC().Unix()), + WindowSizeSeconds: ptr.Of(int64(a.metricsInterval.Seconds())), + }, + Items: make([]*models.MetricsDetailItem, 0), + }) + + // Force an actual slice to avoid non existing fields in the json + if allMetrics.RemediationComponents == nil { + allMetrics.RemediationComponents = make([]*models.RemediationComponentsMetrics, 0) + } + + if allMetrics.LogProcessors == nil { + allMetrics.LogProcessors = make([]*models.LogProcessorsMetrics, 0) + } + + return allMetrics, metricsIds, nil +} + +func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + return a.dbClient.MarkUsageMetricsAsSent(ctx, ids) +} + +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -31,7 +202,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -39,11 +210,16 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { bouncersInfo := make([]*models.MetricsBouncerInfo, len(bouncers)) for i, bouncer := range bouncers { + lastPull := "" + if bouncer.LastPull != nil { + lastPull = bouncer.LastPull.Format(time.RFC3339) + } + bouncersInfo[i] = &models.MetricsBouncerInfo{ Version: bouncer.Version, CustomName: bouncer.Name, Name: bouncer.Type, - LastPull: bouncer.LastPull.Format(time.RFC3339), + LastPull: lastPull, } } @@ -54,8 +230,8 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -75,7 +251,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) { // Metrics are sent at start, then at the randomized metricsIntervalFirst, // then at regular metricsInterval. If a change is detected in the list // of machines, the next metrics are sent immediately. -func (a *apic) SendMetrics(stop chan (bool)) { +func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") // verify the list of machines every interval @@ -99,7 +275,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) @@ -135,7 +311,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } @@ -143,7 +319,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { if metrics != nil { log.Info("capi metrics: sending") - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) + _, _, err = a.apiClient.Metrics.Add(ctx, metrics) if err != nil { log.Errorf("capi metrics: failed: %s", err) } @@ -160,3 +336,52 @@ func (a *apic) SendMetrics(stop chan (bool)) { } } } + +func (a *apic) SendUsageMetrics(ctx context.Context) { + defer trace.CatchPanic("lapi/usageMetricsToAPIC") + + firstRun := true + + log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) + ticker := time.NewTicker(a.usageMetricsIntervalFirst) + + for { + select { + case <-a.metricsTomb.Dying(): + // The normal metrics routine also kills push/pull tombs, does that make sense ? + ticker.Stop() + return + case <-ticker.C: + if firstRun { + firstRun = false + + ticker.Reset(a.usageMetricsInterval) + } + + metrics, metricsId, err := a.GetUsageMetrics(ctx) + if err != nil { + log.Errorf("unable to get usage metrics: %s", err) + continue + } + + _, resp, err := a.apiClient.UsageMetrics.Add(ctx, metrics) + if err != nil { + log.Errorf("unable to send usage metrics: %s", err) + + if resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity { + // In case of 422, mark the metrics as sent anyway, the API did not like what we sent, + // and it's unlikely we'll be able to fix it + continue + } + } + + err = a.MarkUsageMetricsAsSent(ctx, metricsId) + if err != nil { + log.Errorf("unable to mark usage metrics as sent: %s", err) + continue + } + + log.Infof("Sent %d usage metrics", len(metricsId)) + } + } +} diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index 2bc0dd26966..d81af03f710 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -2,7 +2,6 @@ package apiserver import ( "context" - "fmt" "net/url" "testing" "time" @@ -11,12 +10,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/go-cs-lib/version" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" ) func TestAPICSendMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string duration time.Duration @@ -26,18 +25,18 @@ func TestAPICSendMetrics(t *testing.T) { }{ { name: "basic", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) {}, }, { name: "with some metrics", - duration: time.Millisecond * 60, - metricsInterval: time.Millisecond * 10, + duration: time.Millisecond * 120, + metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) { - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) api.dbClient.Ent.Machine.Create(). SetMachineId("1234"). SetPassword(testPassword.String()). @@ -45,16 +44,16 @@ func TestAPICSendMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) api.dbClient.Ent.Bouncer.Create(). SetIPAddress("1.2.3.6"). SetName("someBouncer"). SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) }, }, } @@ -65,7 +64,6 @@ func TestAPICSendMetrics(t *testing.T) { defer httpmock.Deactivate() for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -73,12 +71,12 @@ func TestAPICSendMetrics(t *testing.T) { apiClient, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond api.apiClient = apiClient @@ -87,8 +85,11 @@ func TestAPICSendMetrics(t *testing.T) { tc.setUp(api) stop := make(chan bool) + httpmock.ZeroCallCounters() - go api.SendMetrics(stop) + + go api.SendMetrics(ctx, stop) + time.Sleep(tc.duration) stop <- true diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 74c627cd020..b52dc9e44cc 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -34,12 +34,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func getDBClient(t *testing.T) *database.Client { +func getDBClient(t *testing.T, ctx context.Context) *database.Client { t.Helper() dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), @@ -49,13 +49,13 @@ func getDBClient(t *testing.T) *database.Client { return dbClient } -func getAPIC(t *testing.T) *apic { +func getAPIC(t *testing.T, ctx context.Context) *apic { t.Helper() - dbClient := getDBClient(t) + dbClient := getDBClient(t, ctx) return &apic{ AlertsAddChan: make(chan []*models.Alert), - //DecisionDeleteChan: make(chan []*models.Decision), + // DecisionDeleteChan: make(chan []*models.Decision), dbClient: dbClient, mu: sync.Mutex{}, startup: true, @@ -82,8 +82,8 @@ func absDiff(a int, b int) int { return c } -func assertTotalDecisionCount(t *testing.T, dbClient *database.Client, count int) { - d := dbClient.Ent.Decision.Query().AllX(context.Background()) +func assertTotalDecisionCount(t *testing.T, ctx context.Context, dbClient *database.Client, count int) { + d := dbClient.Ent.Decision.Query().AllX(ctx) assert.Len(t, d, count) } @@ -109,9 +109,10 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { } func TestAPICCAPIPullIsOld(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) - isOld, err := api.CAPIPullIsOld() + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -122,7 +123,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -130,15 +131,17 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) } func TestAPICFetchScenariosListFromDB(t *testing.T) { + ctx := context.Background() + tests := []struct { name string machineIDsWithScenarios map[string]string @@ -162,23 +165,23 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) for machineID, scenarios := range tc.machineIDsWithScenarios { api.dbClient.Ent.Machine.Create(). SetMachineId(machineID). SetPassword(testPassword.String()). SetIpAddress("1.2.3.4"). SetScenarios(scenarios). - ExecX(context.Background()) + ExecX(ctx) } - scenarios, err := api.FetchScenariosListFromDB() + scenarios, err := api.FetchScenariosListFromDB(ctx) + require.NoError(t, err) + for machineID := range tc.machineIDsWithScenarios { - api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(ctx) } - require.NoError(t, err) assert.ElementsMatch(t, tc.expectedScenarios, scenarios) }) @@ -186,6 +189,8 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { } func TestNewAPIC(t *testing.T) { + ctx := context.Background() + var testConfig *csconfig.OnlineApiClientCfg setConfig := func() { @@ -213,7 +218,7 @@ func TestNewAPIC(t *testing.T) { name: "simple", action: func() {}, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, }, @@ -221,7 +226,7 @@ func TestNewAPIC(t *testing.T) { name: "error in parsing URL", action: func() { testConfig.Credentials.URL = "foobar http://" }, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, expectedErr: "first path segment in URL cannot contain colon", @@ -229,10 +234,10 @@ func TestNewAPIC(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { setConfig() httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -244,14 +249,15 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func TestAPICHandleDeletedDecisions(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) _, deleteCounters := makeAddAndDeleteCounters() decision1 := api.dbClient.Ent.Decision.Create(). @@ -272,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { SetOrigin(types.CAPIOrigin). SaveX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 2) + assertTotalDecisionCount(t, ctx, api.dbClient, 2) nbDeleted, err := api.HandleDeletedDecisions([]*models.Decision{{ Value: ptr.Of("1.2.3.4"), @@ -288,9 +294,11 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { } func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -348,10 +356,10 @@ func TestAPICGetMetrics(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - apiClient := getAPIC(t) + apiClient := getAPIC(t, ctx) cleanUp(apiClient) + for i, machineID := range tc.machineIDs { apiClient.dbClient.Ent.Machine.Create(). SetMachineId(machineID). @@ -360,7 +368,7 @@ func TestAPICGetMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) } for i, bouncerName := range tc.bouncers { @@ -370,10 +378,10 @@ func TestAPICGetMetrics(t *testing.T) { SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) @@ -455,7 +463,6 @@ func TestCreateAlertsForDecision(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) { t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want) @@ -535,7 +542,6 @@ func TestFillAlertsWithDecisions(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { addCounters, _ := makeAddAndDeleteCounters() if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) { @@ -546,8 +552,9 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { - api := getAPIC(t) - //one whitelist on IP, one on CIDR + ctx := context.Background() + api := getAPIC(t, ctx) + // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) @@ -569,7 +576,7 @@ func TestAPICWhitelists(t *testing.T) { SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() @@ -592,7 +599,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("13.2.3.4"), //wl by cidr + Value: ptr.Of("13.2.3.4"), // wl by cidr Duration: ptr.Of("24h"), }, }, @@ -613,7 +620,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("13.2.3.5"), //wl by cidr + Value: ptr.Of("13.2.3.5"), // wl by cidr Duration: ptr.Of("24h"), }, }, @@ -633,7 +640,7 @@ func TestAPICWhitelists(t *testing.T) { Scope: ptr.Of("Ip"), Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Value: ptr.Of("9.2.3.4"), //wl by ip + Value: ptr.Of("9.2.3.4"), // wl by ip Duration: ptr.Of("24h"), }, }, @@ -675,16 +682,16 @@ func TestAPICWhitelists(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing + assertTotalDecisionCount(t, ctx, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) @@ -732,7 +739,8 @@ func TestAPICWhitelists(t *testing.T) { } func TestAPICPullTop(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -740,8 +748,8 @@ func TestAPICPullTop(t *testing.T) { SetScope("Ip"). SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). - ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + ExecX(ctx) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() @@ -816,23 +824,22 @@ func TestAPICPullTop(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) + assertTotalDecisionCount(t, ctx, api.dbClient, 5) assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) validDecisions := api.dbClient.Ent.Decision.Query().Where( decision.UntilGT(time.Now())). - AllX(context.Background(), - ) + AllX(context.Background()) decisionScenarioFreq := make(map[string]int) alertScenario := make(map[string]int) @@ -857,8 +864,9 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. - api := getAPIC(t) + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -904,17 +912,17 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -924,15 +932,16 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -996,18 +1005,19 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) } func TestAPICPullBlocklistCall(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -1023,13 +1033,13 @@ func TestAPICPullBlocklistCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullBlocklist(&modelscapi.BlocklistLink{ + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ URL: ptr.Of("http://api.crowdsec.net/blocklist1"), Name: ptr.Of("blocklist1"), Scope: ptr.Of("Ip"), @@ -1040,6 +1050,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { } func TestAPICPush(t *testing.T) { + ctx := context.Background() tests := []struct { name string alerts []*models.Alert @@ -1076,7 +1087,7 @@ func TestAPICPush(t *testing.T) { expectedCalls: 2, alerts: func() []*models.Alert { alerts := make([]*models.Alert, 100) - for i := 0; i < 100; i++ { + for i := range 100 { alerts[i] = &models.Alert{ Scenario: ptr.Of("crowdsec/test"), ScenarioHash: ptr.Of("certified"), @@ -1092,9 +1103,8 @@ func TestAPICPush(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") @@ -1102,22 +1112,29 @@ func TestAPICPush(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) api.apiClient = apic + httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{})) + + // capture the alerts to avoid datarace + alerts := tc.alerts go func() { - api.AlertsAddChan <- tc.alerts + api.AlertsAddChan <- alerts + time.Sleep(time.Second) api.Shutdown() }() - err = api.Push() + + err = api.Push(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount()) }) @@ -1125,7 +1142,8 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) tests := []struct { name string setUp func() @@ -1152,23 +1170,26 @@ func TestAPICPull(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api = getAPIC(t) + api = getAPIC(t, ctx) api.pullInterval = time.Millisecond api.pullIntervalFirst = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) httpmock.Activate() + defer httpmock.DeactivateAndReset() + apic, err := apiclient.NewDefaultClient( url, "/api", - fmt.Sprintf("crowdsec/%s", version.String()), + "", nil, ) require.NoError(t, err) + api.apiClient = apic + httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX( modelscapi.GetDecisionsStreamResponse{ New: modelscapi.GetDecisionsStreamResponseNew{ @@ -1186,18 +1207,22 @@ func TestAPICPull(t *testing.T) { }, ))) tc.setUp() + var buf bytes.Buffer + go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + + if err := api.Pull(ctx); err != nil { panic(err) } }() - //Slightly long because the CI runner for windows are slow, and this can lead to random failure + + // Slightly long because the CI runner for windows are slow, and this can lead to random failure time.Sleep(time.Millisecond * 500) logrus.SetOutput(os.Stderr) assert.Contains(t, buf.String(), tc.logContains) - assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount) + assertTotalDecisionCount(t, ctx, api.dbClient, tc.expectedDecisionCount) }) } } @@ -1279,7 +1304,6 @@ func TestShouldShareAlert(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { ret := shouldShareAlert(tc.alert, tc.consoleConfig) assert.Equal(t, tc.expectedRet, ret) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 19a0085d2dc..35f9beaf635 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -32,6 +32,7 @@ const keyLength = 32 type APIServer struct { URL string + UnixSocket string TLS *csconfig.TLSCfg dbClient *database.Client logFile string @@ -66,7 +67,7 @@ func recoverFromPanic(c *gin.Context) { // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them if strErr, ok := err.(error); ok { - //stolen from http2/server.go in x/net + // stolen from http2/server.go in x/net var ( errClientDisconnected = errors.New("client disconnected") errClosedBody = errors.New("body closed by handler") @@ -83,11 +84,16 @@ func recoverFromPanic(c *gin.Context) { } if brokenPipe { - log.Warningf("client %s disconnected : %s", c.ClientIP(), err) + log.Warningf("client %s disconnected: %s", c.ClientIP(), err) c.Abort() } else { - filename := trace.WriteStackTrace(err) - log.Warningf("client %s error : %s", c.ClientIP(), err) + log.Warningf("client %s error: %s", c.ClientIP(), err) + + filename, err := trace.WriteStackTrace(err) + if err != nil { + log.Errorf("also while writing stacktrace: %s", err) + } + log.Warningf("stacktrace written to %s, please join to your issue", filename) c.AbortWithStatus(http.StatusInternalServerError) } @@ -124,10 +130,10 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro logger := &lumberjack.Logger{ Filename: logFile, - MaxSize: 500, //megabytes + MaxSize: 500, // megabytes MaxBackups: 3, - MaxAge: 28, //days - Compress: true, //disabled by default + MaxAge: 28, // days + Compress: true, // disabled by default } if config.LogMaxSize != 0 { @@ -153,16 +159,16 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro // NewServer creates a LAPI server. // It sets up a gin router, a database client, and a controller. -func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { +func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler - dbClient, err := database.NewClient(config.DbConfig) + dbClient, err := database.NewClient(ctx, config.DbConfig) if err != nil { return nil, fmt.Errorf("unable to init database client: %w", err) } if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) if err != nil { return nil, err } @@ -176,6 +182,13 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { router.ForwardedByClientIP = false + // set the remore address of the request to 127.0.0.1 if it comes from a unix socket + router.Use(func(c *gin.Context) { + if c.Request.RemoteAddr == "@" { + c.Request.RemoteAddr = "127.0.0.1:65535" + } + }) + if config.TrustedProxies != nil && config.UseForwardedForHeaders { if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil { return nil, fmt.Errorf("while setting trusted_proxies: %w", err) @@ -214,17 +227,17 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { controller := &controllers.Controller{ DBClient: dbClient, - Ectx: context.Background(), Router: router, Profiles: config.Profiles, Log: clog, ConsoleConfig: config.ConsoleConfig, DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration, + AutoRegisterCfg: config.AutoRegister, } var ( - apiClient *apic - papiClient *Papi + apiClient *apic + papiClient *Papi ) controller.AlertsAddChan = nil @@ -233,7 +246,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") - apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err } @@ -242,8 +255,8 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { controller.AlertsAddChan = apiClient.AlertsAddChan - if apiClient.apiClient.IsEnrolled() { - if config.ConsoleConfig.IsPAPIEnabled() { + if config.ConsoleConfig.IsPAPIEnabled() && config.OnlineClient.Credentials.PapiURL != "" { + if apiClient.apiClient.IsEnrolled() { log.Info("Machine is enrolled in the console, Loading PAPI Client") papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) @@ -252,9 +265,9 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { } controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel + } else { + log.Error("Machine is not enrolled in the console, can't synchronize with the console") } - } else { - log.Errorf("Machine is not enrolled in the console, can't synchronize with the console") } } @@ -267,6 +280,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { return &APIServer{ URL: config.ListenURI, + UnixSocket: config.ListenSocket, TLS: config.TLS, logFile: logFile, dbClient: dbClient, @@ -284,6 +298,72 @@ func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } +func (s *APIServer) apicPush(ctx context.Context) error { + if err := s.apic.Push(ctx); err != nil { + log.Errorf("capi push: %s", err) + return err + } + + return nil +} + +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { + log.Errorf("capi pull: %s", err) + return err + } + + return nil +} + +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { + log.Errorf("papi pull: %s", err) + return err + } + + return nil +} + +func (s *APIServer) papiSync() error { + if err := s.papi.SyncDecisions(); err != nil { + log.Errorf("capi decisions sync: %s", err) + return err + } + + return nil +} + +func (s *APIServer) initAPIC(ctx context.Context) { + s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) }) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) + + // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios + if s.apic.apiClient.IsEnrolled() { + if s.consoleConfig.IsPAPIEnabled() && s.papi != nil { + if s.papi.URL != "" { + log.Info("Starting PAPI decision receiver") + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) + s.papi.syncTomb.Go(s.papiSync) + } else { + log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") + } + } else { + log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") + } + } + + s.apic.metricsTomb.Go(func() error { + s.apic.SendMetrics(ctx, make(chan bool)) + return nil + }) + + s.apic.metricsTomb.Go(func() error { + s.apic.SendUsageMetrics(ctx) + return nil + }) +} + func (s *APIServer) Run(apiReady chan bool) error { defer trace.CatchPanic("lapi/runServer") @@ -298,84 +378,37 @@ func (s *APIServer) Run(apiReady chan bool) error { TLSConfig: tlsCfg, } - if s.apic != nil { - s.apic.pushTomb.Go(func() error { - if err := s.apic.Push(); err != nil { - log.Errorf("capi push: %s", err) - return err - } - - return nil - }) + ctx := context.TODO() - s.apic.pullTomb.Go(func() error { - if err := s.apic.Pull(); err != nil { - log.Errorf("capi pull: %s", err) - return err - } + if s.apic != nil { + s.initAPIC(ctx) + } - return nil - }) - - //csConfig.API.Server.ConsoleConfig.ShareCustomScenarios - if s.apic.apiClient.IsEnrolled() { - if s.consoleConfig.IsPAPIEnabled() { - if s.papi.URL != "" { - log.Infof("Starting PAPI decision receiver") - s.papi.pullTomb.Go(func() error { - if err := s.papi.Pull(); err != nil { - log.Errorf("papi pull: %s", err) - return err - } - - return nil - }) - - s.papi.syncTomb.Go(func() error { - if err := s.papi.SyncDecisions(); err != nil { - log.Errorf("capi decisions sync: %s", err) - return err - } - - return nil - }) - } else { - log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") - } - } else { - log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") - } - } + s.httpServerTomb.Go(func() error { + return s.listenAndServeLAPI(apiReady) + }) - s.apic.metricsTomb.Go(func() error { - s.apic.SendMetrics(make(chan bool)) - return nil - }) + if err := s.httpServerTomb.Wait(); err != nil { + return fmt.Errorf("local API server stopped with error: %w", err) } - s.httpServerTomb.Go(func() error { s.listenAndServeURL(apiReady); return nil }) - return nil } -// listenAndServeURL starts the http server and blocks until it's closed +// listenAndServeLAPI starts the http server and blocks until it's closed // it also updates the URL field with the actual address the server is listening on // it's meant to be run in a separate goroutine -func (s *APIServer) listenAndServeURL(apiReady chan bool) { - serverError := make(chan error, 1) - - go func() { - listener, err := net.Listen("tcp", s.URL) - if err != nil { - serverError <- fmt.Errorf("listening on %s: %w", s.URL, err) - return - } - - s.URL = listener.Addr().String() - log.Infof("CrowdSec Local API listening on %s", s.URL) - apiReady <- true +func (s *APIServer) listenAndServeLAPI(apiReady chan bool) error { + var ( + tcpListener net.Listener + unixListener net.Listener + err error + serverError = make(chan error, 2) + listenerClosed = make(chan struct{}) + ) - if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { + startServer := func(listener net.Listener, canTLS bool) { + if canTLS && s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { if s.TLS.KeyFilePath == "" { serverError <- errors.New("missing TLS key file") return @@ -391,25 +424,71 @@ func (s *APIServer) listenAndServeURL(apiReady chan bool) { err = s.httpServer.Serve(listener) } - if err != nil && err != http.ErrServerClosed { - serverError <- fmt.Errorf("while serving local API: %w", err) + switch { + case errors.Is(err, http.ErrServerClosed): + break + case err != nil: + serverError <- err + } + } + + // Starting TCP listener + go func() { + if s.URL == "" { + return + } + + tcpListener, err = net.Listen("tcp", s.URL) + if err != nil { + serverError <- fmt.Errorf("listening on %s: %w", s.URL, err) return } + + log.Infof("CrowdSec Local API listening on %s", s.URL) + startServer(tcpListener, true) }() + // Starting Unix socket listener + go func() { + if s.UnixSocket == "" { + return + } + + _ = os.RemoveAll(s.UnixSocket) + + unixListener, err = net.Listen("unix", s.UnixSocket) + if err != nil { + serverError <- fmt.Errorf("while creating unix listener: %w", err) + return + } + + log.Infof("CrowdSec Local API listening on Unix socket %s", s.UnixSocket) + startServer(unixListener, false) + }() + + apiReady <- true + select { case err := <-serverError: - log.Fatalf("while starting API server: %s", err) + return err case <-s.httpServerTomb.Dying(): - log.Infof("Shutting down API server") - // do we need a graceful shutdown here? + log.Info("Shutting down API server") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.httpServer.Shutdown(ctx); err != nil { - log.Errorf("while shutting down http server: %s", err) + log.Errorf("while shutting down http server: %v", err) + } + + close(listenerClosed) + case <-listenerClosed: + if s.UnixSocket != "" { + _ = os.RemoveAll(s.UnixSocket) } } + + return nil } func (s *APIServer) Close() { @@ -437,7 +516,7 @@ func (s *APIServer) Shutdown() error { } } - //close io.writer logger given to gin + // close io.writer logger given to gin if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { pipe.Close() } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index b7f6be5fe36..cdf99462c35 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -1,8 +1,8 @@ package apiserver import ( + "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "os" @@ -28,15 +28,21 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var testMachineID = "test" -var testPassword = strfmt.Password("test") -var MachineTest = models.WatcherAuthRequest{ - MachineID: &testMachineID, - Password: &testPassword, -} +const ( + validRegistrationToken = "igheethauCaeteSaiyee3LosohPhahze" + invalidRegistrationToken = "vohl1feibechieG5coh8musheish2auj" +) -var UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version) -var emptyBody = strings.NewReader("") +var ( + testMachineID = "test" + testPassword = strfmt.Password("test") + MachineTest = models.WatcherRegistrationRequest{ + MachineID: &testMachineID, + Password: &testPassword, + } + UserAgent = "crowdsec-test/" + version.Version + emptyBody = strings.NewReader("") +) func LoadTestConfig(t *testing.T) csconfig.Config { config := csconfig.Config{} @@ -63,6 +69,14 @@ func LoadTestConfig(t *testing.T) csconfig.Config { ShareTaintedScenarios: new(bool), ShareCustomScenarios: new(bool), }, + AutoRegister: &csconfig.LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(true), + Token: validRegistrationToken, + AllowedRanges: []string{ + "127.0.0.1/8", + "::1/128", + }, + }, } apiConfig := csconfig.APICfg{ @@ -73,6 +87,9 @@ func LoadTestConfig(t *testing.T) csconfig.Config { err := config.API.Server.LoadProfiles() require.NoError(t, err) + err = config.API.Server.LoadAutoRegister() + require.NoError(t, err) + return config } @@ -111,15 +128,18 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { err := config.API.Server.LoadProfiles() require.NoError(t, err) + err = config.API.Server.LoadAutoRegister() + require.NoError(t, err) + return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { +func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) log.Printf("Creating new API server") @@ -128,8 +148,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { - apiServer, config := NewAPIServer(t) +func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t, ctx) err := apiServer.InitController() require.NoError(t, err) @@ -140,12 +160,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { +func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) err = apiServer.InitController() @@ -160,19 +180,21 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { - dbClient, err := database.NewClient(config) +func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) { + dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - err = dbClient.ValidateMachine(machineID) + err = dbClient.ValidateMachine(ctx, machineID) require.NoError(t, err) } func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string { - dbClient, err := database.NewClient(config) + ctx := context.Background() + + dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - machines, err := dbClient.ListMachines() + machines, err := dbClient.ListMachines(ctx) require.NoError(t, err) for _, machine := range machines { @@ -245,61 +267,66 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map return response, resp.Code } -func CreateTestMachine(t *testing.T, router *gin.Engine) string { - b, err := json.Marshal(MachineTest) +func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string { + regReq := MachineTest + regReq.RegistrationToken = token + b, err := json.Marshal(regReq) require.NoError(t, err) body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) return body } -func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { - dbClient, err := database.NewClient(config) +func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string { + dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) require.NoError(t, err) return apiKey } func TestWithWrongDBConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) config.API.Server.DbConfig.Type = "test" - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") assert.Nil(t, apiServer) } func TestWithWrongFlushConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) maxItems := -1 config.API.Server.DbConfig.Flush.MaxItems = &maxItems - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) - cstest.RequireErrorContains(t, err, "max_items can't be zero or negative number") + cstest.RequireErrorContains(t, err, "max_items can't be zero or negative") assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) } /* @@ -318,6 +345,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0 */ func TestLoggingDebugToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -339,7 +368,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") expectedLines := []string{"/test42"} cfg.LogLevel = ptr.Of(log.DebugLevel) @@ -347,19 +376,19 @@ func TestLoggingDebugToFileConfig(t *testing.T) { err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) - //wait for the request to happen + assert.Equal(t, http.StatusNotFound, w.Code) + // wait for the request to happen time.Sleep(500 * time.Millisecond) - //check file content + // check file content data, err := os.ReadFile(expectedFile) require.NoError(t, err) @@ -369,6 +398,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } func TestLoggingErrorToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -390,26 +421,26 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) - //wait for the request to happen + // wait for the request to happen time.Sleep(500 * time.Millisecond) - //check file content + // check file content x, err := os.ReadFile(expectedFile) if err == nil { require.Empty(t, x) diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index bab1965123e..719bb231006 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -1,9 +1,9 @@ package controllers import ( - "context" "net" "net/http" + "strings" "github.com/alexliesenfeld/health" "github.com/gin-gonic/gin" @@ -17,7 +17,6 @@ import ( ) type Controller struct { - Ectx context.Context DBClient *database.Client Router *gin.Engine Profiles []*csconfig.ProfileCfg @@ -28,6 +27,7 @@ type Controller struct { ConsoleConfig *csconfig.ConsoleConfig TrustedIPs []net.IPNet HandlerV1 *v1.Controller + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg DisableRemoteLapiRegistration bool } @@ -59,18 +59,35 @@ func serveHealth() http.HandlerFunc { return health.NewHandler(checker) } +func eitherAuthMiddleware(jwtMiddleware gin.HandlerFunc, apiKeyMiddleware gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + switch { + case c.GetHeader("X-Api-Key") != "": + apiKeyMiddleware(c) + case c.GetHeader("Authorization") != "": + jwtMiddleware(c) + // uh no auth header. is this TLS with mutual authentication? + case strings.HasPrefix(c.Request.UserAgent(), "crowdsec/"): + // guess log processors by sniffing user-agent + jwtMiddleware(c) + default: + apiKeyMiddleware(c) + } + } +} + func (c *Controller) NewV1() error { var err error v1Config := v1.ControllerV1Config{ DbClient: c.DBClient, - Ctx: c.Ectx, ProfilesCfg: c.Profiles, DecisionDeleteChan: c.DecisionDeleteChan, AlertsAddChan: c.AlertsAddChan, PluginChannel: c.PluginChannel, ConsoleConfig: *c.ConsoleConfig, TrustedIPs: c.TrustedIPs, + AutoRegisterCfg: c.AutoRegisterCfg, } c.HandlerV1, err = v1.New(&v1Config) @@ -117,6 +134,12 @@ func (c *Controller) NewV1() error { apiKeyAuth.HEAD("/decisions/stream", c.HandlerV1.StreamDecision) } + eitherAuth := groupV1.Group("") + eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc())) + { + eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics) + } + return nil } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index e7d106d72a3..d1f93228512 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -6,10 +6,8 @@ import ( "net" "net/http" "strconv" - "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/google/uuid" @@ -44,6 +42,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { Capacity: &alert.Capacity, Leakspeed: &alert.LeakSpeed, Simulated: &alert.Simulated, + Remediation: alert.Remediation, UUID: alert.UUID, Source: &models.Source{ Scope: &alert.SourceScope, @@ -64,7 +63,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { var Metas models.Meta if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { - log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) + log.Errorf("unable to parse events meta '%s' : %s", eventItem.Serialized, err) } outputAlert.Events = append(outputAlert.Events, &models.Event{ @@ -81,7 +80,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { } for _, decisionItem := range alert.Edges.Decisions { - duration := decisionItem.Until.Sub(time.Now().UTC()).String() + duration := decisionItem.Until.Sub(time.Now().UTC()).Round(time.Second).String() outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{ Duration: &duration, // transform into time.Time ? Scenario: &decisionItem.Scenario, @@ -110,7 +109,7 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest { func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uint) { if c.PluginChannel != nil { RETRY: - for try := 0; try < 3; try++ { + for try := range 3 { select { case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}: log.Debugf("alert sent to Plugin channel") @@ -124,28 +123,12 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin } } -func normalizeScope(scope string) string { - switch strings.ToLower(scope) { - case "ip": - return types.Ip - case "range": - return types.Range - case "as": - return types.AS - case "country": - return types.Country - default: - return scope - } -} - // CreateAlert writes the alerts received in the body to the database func (c *Controller) CreateAlert(gctx *gin.Context) { var input models.AddAlertsRequest - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + ctx := gctx.Request.Context() + machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) @@ -162,12 +145,12 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { for _, alert := range input { // normalize scope for alert.Source and decisions if alert.Source.Scope != nil { - *alert.Source.Scope = normalizeScope(*alert.Source.Scope) + *alert.Source.Scope = types.NormalizeScope(*alert.Source.Scope) } for _, decision := range alert.Decisions { if decision.Scope != nil { - *decision.Scope = normalizeScope(*decision.Scope) + *decision.Scope = types.NormalizeScope(*decision.Scope) } } @@ -177,7 +160,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // if coming from cscli, alert already has decisions if len(alert.Decisions) != 0 { - //alert already has a decision (cscli decisions add etc.), generate uuid here + // alert already has a decision (cscli decisions add etc.), generate uuid here for _, decision := range alert.Decisions { decision.UUID = uuid.NewString() } @@ -257,7 +240,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(machineID, input) + alerts, err := c.DBClient.CreateAlert(ctx, machineID, input) c.DBClient.CanFlush = true if err != nil { @@ -279,7 +262,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return @@ -297,15 +282,16 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { // FindAlertByID returns the alert associated with the ID func (c *Controller) FindAlertByID(gctx *gin.Context) { + ctx := gctx.Request.Context() alertIDStr := gctx.Param("alert_id") - alertID, err := strconv.Atoi(alertIDStr) + alertID, err := strconv.Atoi(alertIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } - result, err := c.DBClient.GetAlertByID(alertID) + result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -325,20 +311,23 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() - if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { + if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } decisionIDStr := gctx.Param("alert_id") + decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } - err = c.DBClient.DeleteAlertByID(decisionID) + err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -351,13 +340,15 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { // DeleteAlerts deletes alerts from the database based on the specified filter func (c *Controller) DeleteAlerts(gctx *gin.Context) { + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() - if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { + if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index ad76ad76616..f8b6aa76ea5 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -1,7 +1,6 @@ package v1 import ( - "context" "fmt" "net" @@ -14,7 +13,6 @@ import ( ) type Controller struct { - Ectx context.Context DBClient *database.Client APIKeyHeader string Middlewares *middlewares.Middlewares @@ -23,22 +21,23 @@ type Controller struct { AlertsAddChan chan []*models.Alert DecisionDeleteChan chan []*models.Decision - PluginChannel chan csplugin.ProfileAlert - ConsoleConfig csconfig.ConsoleConfig - TrustedIPs []net.IPNet + PluginChannel chan csplugin.ProfileAlert + ConsoleConfig csconfig.ConsoleConfig + TrustedIPs []net.IPNet + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg } type ControllerV1Config struct { DbClient *database.Client - Ctx context.Context ProfilesCfg []*csconfig.ProfileCfg AlertsAddChan chan []*models.Alert DecisionDeleteChan chan []*models.Decision - PluginChannel chan csplugin.ProfileAlert - ConsoleConfig csconfig.ConsoleConfig - TrustedIPs []net.IPNet + PluginChannel chan csplugin.ProfileAlert + ConsoleConfig csconfig.ConsoleConfig + TrustedIPs []net.IPNet + AutoRegisterCfg *csconfig.LocalAPIAutoRegisterCfg } func New(cfg *ControllerV1Config) (*Controller, error) { @@ -50,7 +49,6 @@ func New(cfg *ControllerV1Config) (*Controller, error) { } v1 := &Controller{ - Ectx: cfg.Ctx, DBClient: cfg.DbClient, APIKeyHeader: middlewares.APIKeyHeader, Profiles: profiles, @@ -59,9 +57,10 @@ func New(cfg *ControllerV1Config) (*Controller, error) { PluginChannel: cfg.PluginChannel, ConsoleConfig: cfg.ConsoleConfig, TrustedIPs: cfg.TrustedIPs, + AutoRegisterCfg: cfg.AutoRegisterCfg, } - v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) + v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) if err != nil { return v1, err } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index f3c6a7bba26..ffefffc226b 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,8 +1,8 @@ package v1 import ( + "context" "encoding/json" - "fmt" "net/http" "strconv" "time" @@ -20,7 +20,7 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { var results []*models.Decision for _, dbDecision := range decisions { - duration := dbDecision.Until.Sub(time.Now().UTC()).String() + duration := dbDecision.Until.Sub(time.Now().UTC()).Round(time.Second).String() decision := models.Decision{ ID: int64(dbDecision.ID), Duration: &duration, @@ -43,6 +43,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) { data []*ent.Decision ) + ctx := gctx.Request.Context() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) @@ -50,7 +52,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { return } - data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -72,8 +74,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) { return } - if time.Now().UTC().Sub(bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute { + if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -91,7 +93,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } - nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -113,7 +117,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionsWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -134,33 +140,38 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { gctx.JSON(http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(filters) + data, err := dbFunc(ctx, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -172,10 +183,12 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() @@ -186,33 +199,38 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull time.Time, dbFunc func(time.Time, map[string][]string) ([]*ent.Decision, error)) error { - //respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable +func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { + // respBuffer := bytes.NewBuffer([]byte{}) + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(lastPull, filters) + data, err := dbFunc(ctx, lastPull, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -224,10 +242,12 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() @@ -244,7 +264,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B gctx.Writer.Header().Set("Content-Type", "application/json") gctx.Writer.Header().Set("Transfer-Encoding", "chunked") gctx.Writer.WriteHeader(http.StatusOK) - gctx.Writer.Write([]byte(`{"new": [`)) //No need to check for errors, the doc says it always returns nil + gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { @@ -252,48 +272,47 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) if err != nil { log.Errorf("failed sending new decisions for startup: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) - //Expired decisions + gctx.Writer.WriteString(`], "deleted": [`) + // Expired decisions err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters) if err != nil { log.Errorf("failed sending expired decisions for startup: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() } else { err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryNewDecisionsSinceWithFilters) if err != nil { log.Errorf("failed sending new decisions for delta: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) + gctx.Writer.WriteString(`], "deleted": [`) err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryExpiredDecisionsSinceWithFilters) - if err != nil { log.Errorf("failed sending expired decisions for delta: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() } @@ -301,8 +320,12 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B } func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { - var data []*ent.Decision - var err error + var ( + data []*ent.Decision + err error + ) + + ctx := gctx.Request.Context() ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} @@ -310,18 +333,18 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err = c.DBClient.QueryAllDecisionsWithFilters(filters) + data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -338,18 +361,23 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) + since := time.Time{} + if bouncerInfo.LastPull != nil { + since = bouncerInfo.LastPull.Add(-2 * time.Second) + } + // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -366,6 +394,8 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en func (c *Controller) StreamDecision(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + streamStartTime := time.Now().UTC() bouncerInfo, err := getBouncerFromContext(gctx) @@ -376,8 +406,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } if gctx.Request.Method == http.MethodHead { - //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db - //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) + // For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db + // We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) gctx.String(http.StatusOK, "") return @@ -395,8 +425,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } if err == nil { - //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + // Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions + if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/controllers/v1/errors.go b/pkg/apiserver/controllers/v1/errors.go index b85b811f8a7..d661de44b0e 100644 --- a/pkg/apiserver/controllers/v1/errors.go +++ b/pkg/apiserver/controllers/v1/errors.go @@ -1,35 +1,36 @@ package v1 import ( + "errors" "net/http" + "strings" "github.com/gin-gonic/gin" - "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database" ) func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { - switch errors.Cause(err) { - case database.ItemNotFound: + switch { + case errors.Is(err, database.ItemNotFound): gctx.JSON(http.StatusNotFound, gin.H{"message": err.Error()}) return - case database.UserExists: + case errors.Is(err, database.UserExists): gctx.JSON(http.StatusForbidden, gin.H{"message": err.Error()}) return - case database.HashError: + case errors.Is(err, database.HashError): gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return - case database.InsertFail: + case errors.Is(err, database.InsertFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return - case database.QueryFail: + case errors.Is(err, database.QueryFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return - case database.ParseTimeFail: + case errors.Is(err, database.ParseTimeFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return - case database.ParseDurationFail: + case errors.Is(err, database.ParseDurationFail): gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return default: @@ -37,3 +38,32 @@ func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { return } } + +// collapseRepeatedPrefix collapses repeated occurrences of a given prefix in the text +func collapseRepeatedPrefix(text string, prefix string) string { + count := 0 + for strings.HasPrefix(text, prefix) { + count++ + text = strings.TrimPrefix(text, prefix) + } + + if count > 0 { + return prefix + text + } + + return text +} + +// RepeatedPrefixError wraps an error and removes the repeating prefix from its message +type RepeatedPrefixError struct { + OriginalError error + Prefix string +} + +func (e RepeatedPrefixError) Error() string { + return collapseRepeatedPrefix(e.OriginalError.Error(), e.Prefix) +} + +func (e RepeatedPrefixError) Unwrap() error { + return e.OriginalError +} diff --git a/pkg/apiserver/controllers/v1/errors_test.go b/pkg/apiserver/controllers/v1/errors_test.go new file mode 100644 index 00000000000..89c561f83bd --- /dev/null +++ b/pkg/apiserver/controllers/v1/errors_test.go @@ -0,0 +1,57 @@ +package v1 + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCollapseRepeatedPrefix(t *testing.T) { + tests := []struct { + input string + prefix string + want string + }{ + { + input: "aaabbbcccaaa", + prefix: "aaa", + want: "aaabbbcccaaa", + }, { + input: "hellohellohello world", + prefix: "hello", + want: "hello world", + }, { + input: "ababababxyz", + prefix: "ab", + want: "abxyz", + }, { + input: "xyzxyzxyzxyzxyz", + prefix: "xyz", + want: "xyz", + }, { + input: "123123123456", + prefix: "456", + want: "123123123456", + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, collapseRepeatedPrefix(tt.input, tt.prefix)) + }) + } +} + +func TestRepeatedPrefixError(t *testing.T) { + originalErr := errors.New("hellohellohello world") + wrappedErr := RepeatedPrefixError{OriginalError: originalErr, Prefix: "hello"} + + want := "hello world" + + assert.Equal(t, want, wrappedErr.Error()) + + assert.Equal(t, originalErr, errors.Unwrap(wrappedErr)) + require.ErrorIs(t, wrappedErr, originalErr) +} diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index b19b450f0d5..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -3,16 +3,15 @@ package v1 import ( "net/http" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + ctx := gctx.Request.Context() + + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 84a6ef2583c..ff59e389cb1 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -1,16 +1,53 @@ package v1 import ( + "errors" + "net" "net/http" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) +func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, error) { + if !*c.AutoRegisterCfg.Enable { + return false, nil + } + + clientIP := net.ParseIP(gctx.ClientIP()) + + // Can probaby happen if using unix socket ? + if clientIP == nil { + log.Warnf("Failed to parse client IP for watcher self registration: %s", gctx.ClientIP()) + return false, nil + } + + if token == "" || c.AutoRegisterCfg == nil { + return false, nil + } + + // Check the token + if token != c.AutoRegisterCfg.Token { + return false, errors.New("invalid token for auto registration") + } + + // Check the source IP + for _, ipRange := range c.AutoRegisterCfg.AllowedRangesParsed { + if ipRange.Contains(clientIP) { + return true, nil + } + } + + return false, errors.New("IP not in allowed range for auto registration") +} + func (c *Controller) CreateMachine(gctx *gin.Context) { + ctx := gctx.Request.Context() + var input models.WatcherRegistrationRequest if err := gctx.ShouldBindJSON(&input); err != nil { @@ -19,14 +56,27 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { } if err := input.Validate(strfmt.Default); err != nil { - c.HandleDBErrors(gctx, err) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) return } - if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil { + autoRegister, err := c.shouldAutoRegister(input.RegistrationToken, gctx) + if err != nil { + log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Errorf("Auto-register failed: %s", err) + gctx.JSON(http.StatusUnauthorized, gin.H{"message": err.Error()}) + + return + } + + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } - gctx.Status(http.StatusCreated) + if autoRegister { + log.WithFields(log.Fields{"ip": gctx.ClientIP(), "machine_id": *input.MachineID}).Info("Auto-registered machine") + gctx.Status(http.StatusAccepted) + } else { + gctx.Status(http.StatusCreated) + } } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 13ccf9ac94f..4f6ee0986eb 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -3,7 +3,6 @@ package v1 import ( "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) @@ -66,32 +65,32 @@ var LapiResponseTime = prometheus.NewHistogramVec( []string{"endpoint", "method"}) func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } func PrometheusMachinesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - claims := jwt.ExtractClaims(c) - if claims != nil { - if rawID, ok := claims["id"]; ok { - machineID := rawID.(string) - LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() - } + machineID, _ := getMachineIDFromContext(c) + if machineID != "" { + LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": c.Request.URL.Path, + "method": c.Request.Method, + }).Inc() } c.Next() @@ -100,12 +99,13 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { func PrometheusBouncersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiBouncerHits.With(prometheus.Labels{ - "bouncer": name.(string), + "bouncer": bouncer.Name, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } c.Next() @@ -118,7 +118,8 @@ func PrometheusMiddleware() gin.HandlerFunc { LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() c.Next() elapsed := time.Since(startTime) diff --git a/pkg/apiserver/controllers/v1/usagemetrics.go b/pkg/apiserver/controllers/v1/usagemetrics.go new file mode 100644 index 00000000000..5b2c3e3b1a9 --- /dev/null +++ b/pkg/apiserver/controllers/v1/usagemetrics.go @@ -0,0 +1,205 @@ +package v1 + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +// updateBaseMetrics updates the base metrics for a machine or bouncer +func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { + switch { + case machineID != "": + return c.DBClient.MachineUpdateBaseMetrics(ctx, machineID, baseMetrics, hubItems, datasources) + case bouncer != nil: + return c.DBClient.BouncerUpdateBaseMetrics(ctx, bouncer.Name, bouncer.Type, baseMetrics) + default: + return errors.New("no machineID or bouncerName set") + } +} + +// UsageMetrics receives metrics from log processors and remediation components +func (c *Controller) UsageMetrics(gctx *gin.Context) { + var input models.AllMetrics + + logger := log.WithField("func", "UsageMetrics") + + // parse the payload + + if err := gctx.ShouldBindJSON(&input); err != nil { + logger.Errorf("Failed to bind json: %s", err) + gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + + return + } + + if err := input.Validate(strfmt.Default); err != nil { + // work around a nuisance in the generated code + cleanErr := RepeatedPrefixError{ + OriginalError: err, + Prefix: "validation failure list:\n", + } + logger.Errorf("Failed to validate usage metrics: %s", cleanErr) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": cleanErr.Error()}) + + return + } + + var ( + generatedType metric.GeneratedType + generatedBy string + ) + + bouncer, _ := getBouncerFromContext(gctx) + if bouncer != nil { + logger.Tracef("Received usage metris for bouncer: %s", bouncer.Name) + + generatedType = metric.GeneratedTypeRC + generatedBy = bouncer.Name + } + + machineID, _ := getMachineIDFromContext(gctx) + if machineID != "" { + logger.Tracef("Received usage metrics for log processor: %s", machineID) + + generatedType = metric.GeneratedTypeLP + generatedBy = machineID + } + + if generatedBy == "" { + // how did we get here? + logger.Error("No machineID or bouncer in request context after authentication") + gctx.JSON(http.StatusInternalServerError, gin.H{"message": "No machineID or bouncer in request context after authentication"}) + + return + } + + if machineID != "" && bouncer != nil { + logger.Errorf("Payload has both machineID and bouncer") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has both LP and RC data"}) + + return + } + + var ( + payload map[string]any + baseMetrics models.BaseMetrics + hubItems models.HubItems + datasources map[string]int64 + ) + + switch len(input.LogProcessors) { + case 0: + if machineID != "" { + logger.Errorf("Missing log processor data") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing log processor data"}) + + return + } + case 1: + // the final slice can't have more than one item, + // guaranteed by the swagger schema + item0 := input.LogProcessors[0] + + err := item0.Validate(strfmt.Default) + if err != nil { + logger.Errorf("Failed to validate log processor data: %s", err) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + + return + } + + payload = map[string]any{ + "metrics": item0.Metrics, + } + baseMetrics = item0.BaseMetrics + hubItems = item0.HubItems + datasources = item0.Datasources + default: + logger.Errorf("Payload has more than one log processor") + // this is not checked in the swagger schema + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one log processor"}) + + return + } + + switch len(input.RemediationComponents) { + case 0: + if bouncer != nil { + logger.Errorf("Missing remediation component data") + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Missing remediation component data"}) + + return + } + case 1: + item0 := input.RemediationComponents[0] + + err := item0.Validate(strfmt.Default) + if err != nil { + logger.Errorf("Failed to validate remediation component data: %s", err) + gctx.JSON(http.StatusUnprocessableEntity, gin.H{"message": err.Error()}) + + return + } + + payload = map[string]any{ + "type": item0.Type, + "metrics": item0.Metrics, + } + baseMetrics = item0.BaseMetrics + default: + gctx.JSON(http.StatusBadRequest, gin.H{"message": "Payload has more than one remediation component"}) + return + } + + if baseMetrics.Os == nil { + baseMetrics.Os = &models.OSversion{ + Name: ptr.Of(""), + Version: ptr.Of(""), + } + } + + ctx := gctx.Request.Context() + + err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources) + if err != nil { + logger.Errorf("Failed to update base metrics: %s", err) + c.HandleDBErrors(gctx, err) + + return + } + + jsonPayload, err := json.Marshal(payload) + if err != nil { + logger.Errorf("Failed to serialize usage metrics: %s", err) + c.HandleDBErrors(gctx, err) + + return + } + + receivedAt := time.Now().UTC() + + if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { + logger.Error(err) + c.HandleDBErrors(gctx, err) + + return + } + + // if CreateMetrics() returned nil, the metric was already there, we're good + // and don't split hair about 201 vs 200/204 + + gctx.Status(http.StatusCreated) +} diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 6afd005132a..3cd53d217cc 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -1,34 +1,72 @@ package v1 import ( - "fmt" + "errors" + "net" "net/http" + "strings" + jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -const bouncerContextKey = "bouncer_info" - func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(bouncerContextKey) + bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) if !exist { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) if !ok { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } return bouncerInfo, nil } +func isUnixSocket(c *gin.Context) bool { + if localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { + return strings.HasPrefix(localAddr.Network(), "unix") + } + + return false +} + +func getMachineIDFromContext(ctx *gin.Context) (string, error) { + claims := jwt.ExtractClaims(ctx) + if claims == nil { + return "", errors.New("failed to extract claims") + } + + rawID, ok := claims[middlewares.MachineIDKey] + if !ok { + return "", errors.New("MachineID not found in claims") + } + + id, ok := rawID.(string) + if !ok { + // should never happen + return "", errors.New("failed to cast machineID to string") + } + + return id, nil +} + func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { return func(gctx *gin.Context) { + if !option { + return + } + + if isUnixSocket(gctx) { + return + } + incomingIP := gctx.ClientIP() - if option && incomingIP != "127.0.0.1" && incomingIP != "::1" { + if incomingIP != "127.0.0.1" && incomingIP != "::1" { gctx.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) gctx.Abort() } diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index e4c9dda47ce..a0af6956443 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -12,82 +13,86 @@ const ( ) func TestDeleteDecisionRange(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } func TestDeleteDecisionFilter(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteDecisionFilterByScenario(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by wrong scenario - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) } func TestGetDecisionFilters(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // Get Decision - w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -101,7 +106,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -118,7 +123,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -132,7 +137,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -145,7 +150,7 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -155,13 +160,14 @@ func TestGetDecisionFilters(t *testing.T) { } func TestGetDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -180,51 +186,52 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) assert.Len(t, decisions, 3) } func TestDeleteDecisionByID(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") - //Have one alerts - w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) errResp, _ = readDecisionsErrorResp(t, w) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) - //Have one alerts - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with valid ID - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "1", resp.NbDeleted) - //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert (because we delete an alert that has dup targets) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -232,33 +239,35 @@ func TestDeleteDecisionByID(t *testing.T) { } func TestDeleteDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "3", resp.NbDeleted) } func TestStreamStartDecisionDedup(t *testing.T) { - //Ensure that at stream startup we only get the longest decision - lapi := SetupLAPITest(t) + ctx := context.Background() + // Ensure that at stream startup we only get the longest decision + lapi := SetupLAPITest(t, ctx) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -268,11 +277,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -282,11 +291,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -296,11 +305,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - //and now we only get a deleted decision - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // and now we only get a deleted decision + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions["deleted"], 1) diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index fbf01c7fb8e..db051566f75 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "testing" @@ -8,11 +9,12 @@ import ( ) func TestHeartBeat(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) - w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 58f66cfc74f..f6f51763975 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -10,13 +11,14 @@ import ( ) func TestLogin(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) - body := CreateTestMachine(t, router) + body := CreateTestMachine(t, ctx, router, "") // Login with machine not validated yet w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +27,7 @@ func TestLogin(t *testing.T) { // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -34,7 +36,7 @@ func TestLogin(t *testing.T) { // Login with invalid body w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -43,19 +45,19 @@ func TestLogin(t *testing.T) { // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) - //Validate machine - ValidateMachine(t, "test", config.API.Server.DbConfig) + // Validate machine + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -64,7 +66,7 @@ func TestLogin(t *testing.T) { // Login with valid machine w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -74,7 +76,7 @@ func TestLogin(t *testing.T) { // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 08efa91c6c1..969f75707d6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,27 +10,30 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestCreateMachine(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Create machine with invalid format w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("test")) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 400, w.Code) + assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusUnprocessableEntity, w.Code) assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) // Create machine @@ -39,17 +43,19 @@ func TestCreateMachine(t *testing.T) { body := string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) router.TrustedPlatform = "X-Real-IP" + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -57,12 +63,12 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-Ip", "1.1.1.1") router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) @@ -71,7 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -80,23 +87,24 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-IP", "1.1.1.1") router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + // For some reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -105,35 +113,121 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, "", w.Body.String()) ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + // For some reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) - body := CreateTestMachine(t, router) + body := CreateTestMachine(t, ctx, router, "") w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) } + +func TestAutoRegistration(t *testing.T) { + ctx := context.Background() + router, _ := NewAPITest(t, ctx) + + // Invalid registration token / valid source IP + regReq := MachineTest + regReq.RegistrationToken = invalidRegistrationToken + b, err := json.Marshal(regReq) + require.NoError(t, err) + + body := string(b) + + w := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // Invalid registration token / invalid source IP + regReq = MachineTest + regReq.RegistrationToken = invalidRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "42.42.42.42:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // valid registration token / invalid source IP + regReq = MachineTest + regReq.RegistrationToken = validRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "42.42.42.42:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + + // Valid registration token / valid source IP + regReq = MachineTest + regReq.RegistrationToken = validRegistrationToken + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusAccepted, w.Code) + + // No token / valid source IP + regReq = MachineTest + regReq.MachineID = ptr.Of("test2") + b, err = json.Marshal(regReq) + require.NoError(t, err) + + body = string(b) + + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:4242" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) +} diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index ae7645e1b85..d438c9b15a4 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -18,9 +18,9 @@ import ( const ( APIKeyHeader = "X-Api-Key" - bouncerContextKey = "bouncer_info" + BouncerContextKey = "bouncer_info" + dummyAPIKeySize = 54 // max allowed by bcrypt 72 = 54 bytes in base64 - dummyAPIKeySize = 54 ) type APIKey struct { @@ -60,32 +60,27 @@ func HashSHA512(str string) string { func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { if a.TlsAuth == nil { - logger.Error("TLS Auth is not configured but client presented a certificate") + logger.Warn("TLS Auth is not configured but client presented a certificate") return nil } - validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) - if !validCert { - logger.Errorf("invalid client certificate: %s", err) - return nil - } + ctx := c.Request.Context() + extractedCN, err := a.TlsAuth.ValidateCert(c) if err != nil { - logger.Error(err) + logger.Warn(err) return nil } - logger = logger.WithFields(log.Fields{ - "cn": extractedCN, - }) + logger = logger.WithField("cn", extractedCN) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) - //This is likely not the proper way, but isNotFound does not seem to work + // This is likely not the proper way, but isNotFound does not seem to work if err != nil && strings.Contains(err.Error(), "bouncer not found") { - //Because we have a valid cert, automatically create the bouncer in the database if it does not exist - //Set a random API key, but it will never be used + // Because we have a valid cert, automatically create the bouncer in the database if it does not exist + // Set a random API key, but it will never be used apiKey, err := GenerateAPIKey(dummyAPIKeySize) if err != nil { logger.Errorf("error generating mock api key: %s", err) @@ -94,17 +89,17 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil } } else if err != nil { - //error while selecting bouncer + // error while selecting bouncer logger.Errorf("while selecting bouncers: %s", err) return nil } else if bouncer.AuthType != types.TlsAuthType { - //bouncer was found in DB + // bouncer was found in DB logger.Errorf("bouncer isn't allowed to auth by TLS") return nil } @@ -119,9 +114,11 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) + bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil @@ -139,9 +136,11 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer - logger := log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }) + ctx := c.Request.Context() + + clientIP := c.ClientIP() + + logger := log.WithField("ip", clientIP) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { bouncer = a.authTLS(c, logger) @@ -150,22 +149,17 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer == nil { + // XXX: StatusUnauthorized? c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } - logger = logger.WithFields(log.Fields{ - "name": bouncer.Name, - }) - - // maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - // in StreamDecision - c.Set("BOUNCER_NAME", bouncer.Name) - c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) + logger = logger.WithField("name", bouncer.Name) if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -174,11 +168,11 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - //Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided - if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { - log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) + // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided + if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { + log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) - if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -194,7 +188,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() @@ -203,7 +197,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - c.Set(bouncerContextKey, bouncer) - c.Next() + c.Set(BouncerContextKey, bouncer) } } diff --git a/pkg/apiserver/middlewares/v1/cache.go b/pkg/apiserver/middlewares/v1/cache.go new file mode 100644 index 00000000000..b0037bc4fa4 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/cache.go @@ -0,0 +1,99 @@ +package v1 + +import ( + "crypto/x509" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type cacheEntry struct { + err error // if nil, the certificate is not revocated + timestamp time.Time +} + +type RevocationCache struct { + mu sync.RWMutex + cache map[string]cacheEntry + expiration time.Duration + lastPurge time.Time + logger *log.Entry +} + +func NewRevocationCache(expiration time.Duration, logger *log.Entry) *RevocationCache { + return &RevocationCache{ + cache: make(map[string]cacheEntry), + expiration: expiration, + lastPurge: time.Now(), + logger: logger, + } +} + +func (*RevocationCache) generateKey(cert *x509.Certificate) string { + return cert.SerialNumber.String() + "-" + cert.Issuer.String() +} + +// purge removes expired entries from the cache +func (rc *RevocationCache) purgeExpired() { + // we don't keep a separate interval for the full sweep, we'll just double the expiration + if time.Since(rc.lastPurge) < rc.expiration { + return + } + + rc.mu.Lock() + defer rc.mu.Unlock() + + for key, entry := range rc.cache { + if time.Since(entry.timestamp) > rc.expiration { + rc.logger.Debugf("purging expired entry for cert %s", key) + delete(rc.cache, key) + } + } +} + +func (rc *RevocationCache) Get(cert *x509.Certificate) (error, bool) { //nolint:revive + rc.purgeExpired() + key := rc.generateKey(cert) + rc.mu.RLock() + entry, exists := rc.cache[key] + rc.mu.RUnlock() + + if !exists { + rc.logger.Tracef("no cached value for cert %s", key) + return nil, false + } + + // Upgrade to write lock to potentially modify the cache + rc.mu.Lock() + defer rc.mu.Unlock() + + if entry.timestamp.Add(rc.expiration).Before(time.Now()) { + rc.logger.Debugf("cached value for %s expired, removing from cache", key) + delete(rc.cache, key) + + return nil, false + } + + rc.logger.Debugf("using cached value for cert %s: %v", key, entry.err) + + return entry.err, true +} + +func (rc *RevocationCache) Set(cert *x509.Certificate, err error) { + key := rc.generateKey(cert) + + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.cache[key] = cacheEntry{ + err: err, + timestamp: time.Now(), + } +} + +func (rc *RevocationCache) Empty() { + rc.mu.Lock() + defer rc.mu.Unlock() + rc.cache = make(map[string]cacheEntry) +} diff --git a/pkg/apiserver/middlewares/v1/crl.go b/pkg/apiserver/middlewares/v1/crl.go new file mode 100644 index 00000000000..64d7d3f0d96 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/crl.go @@ -0,0 +1,145 @@ +package v1 + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type CRLChecker struct { + path string // path to the CRL file + fileInfo os.FileInfo // last stat of the CRL file + crls []*x509.RevocationList // parsed CRLs + logger *log.Entry + mu sync.RWMutex + lastLoad time.Time // time when the CRL file was last read successfully + onLoad func() // called when the CRL file changes (and is read successfully) +} + +func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) { + cc := &CRLChecker{ + path: crlPath, + logger: logger, + onLoad: onLoad, + } + + err := cc.refresh() + if err != nil { + return nil, err + } + + return cc, nil +} + +func (*CRLChecker) decodeCRLs(content []byte) ([]*x509.RevocationList, error) { + var crls []*x509.RevocationList + + for { + block, rest := pem.Decode(content) + if block == nil { + break // no more PEM blocks + } + + content = rest + + crl, err := x509.ParseRevocationList(block.Bytes) + if err != nil { + // invalidate the whole CRL file so we can still use the previous version + return nil, fmt.Errorf("could not parse file: %w", err) + } + + crls = append(crls, crl) + } + + return crls, nil +} + +// refresh() reads the CRL file if new or changed since the last time +func (cc *CRLChecker) refresh() error { + // noop if lastLoad is less than 5 seconds ago + if time.Since(cc.lastLoad) < 5*time.Second { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + cc.logger.Debugf("loading CRL file from %s", cc.path) + + fileInfo, err := os.Stat(cc.path) + if err != nil { + return fmt.Errorf("could not access CRL file: %w", err) + } + + // noop if the file didn't change + if cc.fileInfo != nil && fileInfo.ModTime().Equal(cc.fileInfo.ModTime()) && fileInfo.Size() == cc.fileInfo.Size() { + return nil + } + + // the encoding/pem package wants bytes, not io.Reader + crlContent, err := os.ReadFile(cc.path) + if err != nil { + return fmt.Errorf("could not read CRL file: %w", err) + } + + cc.crls, err = cc.decodeCRLs(crlContent) + if err != nil { + return err + } + + cc.fileInfo = fileInfo + cc.lastLoad = time.Now() + cc.onLoad() + + return nil +} + +// isRevoked checks if the client certificate is revoked by any of the CRL blocks +// It returns a boolean indicating if the certificate is revoked and a boolean indicating +// if the CRL check was successful and could be cached. +func (cc *CRLChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { + if cc == nil { + return false, true + } + + err := cc.refresh() + if err != nil { + // we can't quit obviously, so we just log the error and continue + // but we can assume we have loaded a CRL, or it would have quit the first time + cc.logger.Errorf("while refreshing CRL: %s - will keep using CRL file read at %s", err, + cc.lastLoad.Format(time.RFC3339)) + } + + now := time.Now().UTC() + + cc.mu.RLock() + defer cc.mu.RUnlock() + + for _, crl := range cc.crls { + if err := crl.CheckSignatureFrom(issuer); err != nil { + continue + } + + if now.After(crl.NextUpdate) { + cc.logger.Warn("CRL has expired, will still validate the cert against it.") + } + + if now.Before(crl.ThisUpdate) { + cc.logger.Warn("CRL is not yet valid, will still validate the cert against it.") + } + + for _, revoked := range crl.RevokedCertificateEntries { + if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { + cc.logger.Warn("client certificate is revoked by CRL") + return true, true + } + } + } + + return false, true +} diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index ed4ad107b96..9171e9fce06 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "errors" "fmt" - "net/http" "os" "strings" "time" @@ -22,7 +21,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var identityKey = "id" +const MachineIDKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware @@ -33,7 +32,7 @@ type JWT struct { func PayloadFunc(data interface{}) jwt.MapClaims { if value, ok := data.(*models.WatcherAuthRequest); ok { return jwt.MapClaims{ - identityKey: &value.MachineID, + MachineIDKey: &value.MachineID, } } @@ -42,7 +41,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims { func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineID := claims[identityKey].(string) + machineID := claims[MachineIDKey].(string) return &models.WatcherAuthRequest{ MachineID: &machineID, @@ -56,51 +55,44 @@ type authInput struct { } func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + ctx := c.Request.Context() ret := authInput{} if j.TlsAuth == nil { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, errors.New("TLS auth is not configured") + err := errors.New("tls authentication required") + log.Warn(err) + + return nil, err } - validCert, extractedCN, err := j.TlsAuth.ValidateCert(c) + extractedCN, err := j.TlsAuth.ValidateCert(c) if err != nil { - log.Error(err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - - return nil, fmt.Errorf("while trying to validate client cert: %w", err) + log.Warn(err) + return nil, err } - if !validCert { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, fmt.Errorf("failed cert authentication") - } + logger := log.WithField("ip", c.ClientIP()) ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if ent.IsNotFound(err) { - //Machine was not found, let's create it - log.Infof("machine %s not found, create it", ret.machineID) - //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) + // Machine was not found, let's create it + logger.Infof("machine %s not found, create it", ret.machineID) + // let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) pwd, err := GenerateAPIKey(dummyAPIKeySize) if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("error generating password: %s", err) + logger.WithField("cn", extractedCN). + Errorf("error generating password: %s", err) - return nil, fmt.Errorf("error generating password") + return nil, errors.New("error generating password") } password := strfmt.Password(pwd) - ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } @@ -110,6 +102,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { if ret.clientMachine.AuthType != types.TlsAuthType { return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) } + ret.machineID = ret.clientMachine.MachineId } @@ -135,6 +128,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { err error ) + ctx := c.Request.Context() + ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { @@ -151,7 +146,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err @@ -183,6 +178,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { auth *authInput ) + ctx := c.Request.Context() + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) if err != nil { @@ -206,25 +203,27 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } + clientIP := c.ClientIP() + if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } - if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { - log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) + if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { + log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -233,13 +232,14 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { useragent := strings.Split(c.Request.UserAgent(), "/") if len(useragent) != 2 { - log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP()) + log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), clientIP) return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) - log.Errorf("bad user agent from : %s", c.ClientIP()) + log.Errorf("bad user agent from : %s", clientIP) + return nil, jwt.ErrFailedAuthentication } @@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { Key: secret, Timeout: time.Hour, MaxRefresh: time.Hour, - IdentityKey: identityKey, + IdentityKey: MachineIDKey, PayloadFunc: PayloadFunc, IdentityHandler: IdentityHandler, Authenticator: jwtMiddleware.Authenticator, @@ -323,8 +323,9 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { errInit := ret.MiddlewareInit() if errInit != nil { - return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) + return &JWT{}, errors.New("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) } + jwtMiddleware.Middleware = ret return jwtMiddleware, nil diff --git a/pkg/apiserver/middlewares/v1/ocsp.go b/pkg/apiserver/middlewares/v1/ocsp.go new file mode 100644 index 00000000000..0b6406ad0e7 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/ocsp.go @@ -0,0 +1,100 @@ +package v1 + +import ( + "bytes" + "crypto" + "crypto/x509" + "io" + "net/http" + "net/url" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ocsp" +) + +type OCSPChecker struct { + logger *log.Entry +} + +func NewOCSPChecker(logger *log.Entry) *OCSPChecker { + return &OCSPChecker{ + logger: logger, + } +} + +func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) { + req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) + if err != nil { + oc.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) + return nil, err + } + + httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) + if err != nil { + oc.logger.Error("TLSAuth: cannot create HTTP request for OCSP") + return nil, err + } + + ocspURL, err := url.Parse(server) + if err != nil { + oc.logger.Error("TLSAuth: cannot parse OCSP URL") + return nil, err + } + + httpRequest.Header.Add("Content-Type", "application/ocsp-request") + httpRequest.Header.Add("Accept", "application/ocsp-response") + httpRequest.Header.Add("Host", ocspURL.Host) + + httpClient := &http.Client{} + + // XXX: timeout, context? + httpResponse, err := httpClient.Do(httpRequest) + if err != nil { + oc.logger.Error("TLSAuth: cannot send HTTP request to OCSP") + return nil, err + } + defer httpResponse.Body.Close() + + output, err := io.ReadAll(httpResponse.Body) + if err != nil { + oc.logger.Error("TLSAuth: cannot read HTTP response from OCSP") + return nil, err + } + + ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) + + return ocspResponse, err +} + +// isRevokedBy checks if the client certificate is revoked by the issuer via any of the OCSP servers present in the certificate. +// It returns a boolean indicating if the certificate is revoked and a boolean indicating +// if the OCSP check was successful and could be cached. +func (oc *OCSPChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { + if len(cert.OCSPServer) == 0 { + oc.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") + return false, true + } + + for _, server := range cert.OCSPServer { + ocspResponse, err := oc.query(server, cert, issuer) + if err != nil { + oc.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) + continue + } + + switch ocspResponse.Status { + case ocsp.Good: + return false, true + case ocsp.Revoked: + oc.logger.Errorf("TLSAuth: client certificate is revoked by server %s", server) + return true, true + case ocsp.Unknown: + log.Debugf("unknown OCSP status for server %s", server) + continue + } + } + + log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") + + return true, false +} diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index 904f6cd445a..673c8d0cdce 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -1,78 +1,24 @@ package v1 import ( - "bytes" - "crypto" "crypto/x509" + "errors" "fmt" - "io" - "net/http" - "net/url" - "os" + "slices" "time" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ocsp" ) type TLSAuth struct { AllowedOUs []string - CrlPath string - revokationCache map[string]cacheEntry - cacheExpiration time.Duration + crlChecker *CRLChecker + ocspChecker *OCSPChecker + revocationCache *RevocationCache logger *log.Entry } -type cacheEntry struct { - revoked bool - err error - timestamp time.Time -} - -func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) { - req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) - if err != nil { - ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) - return nil, err - } - - httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) - if err != nil { - ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") - return nil, err - } - - ocspURL, err := url.Parse(server) - if err != nil { - ta.logger.Error("TLSAuth: cannot parse OCSP URL") - return nil, err - } - - httpRequest.Header.Add("Content-Type", "application/ocsp-request") - httpRequest.Header.Add("Accept", "application/ocsp-response") - httpRequest.Header.Add("host", ocspURL.Host) - - httpClient := &http.Client{} - - httpResponse, err := httpClient.Do(httpRequest) - if err != nil { - ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") - return nil, err - } - defer httpResponse.Body.Close() - - output, err := io.ReadAll(httpResponse.Body) - if err != nil { - ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP") - return nil, err - } - - ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) - - return ocspResponse, err -} - func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { now := time.Now().UTC() @@ -89,207 +35,147 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { return false } -func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) { - ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") - return false, nil - } +// checkRevocationPath checks a single chain against OCSP and CRL +func (ta *TLSAuth) checkRevocationPath(chain []*x509.Certificate) (error, bool) { //nolint:revive + // if we ever fail to check OCSP or CRL, we should not cache the result + couldCheck := true - for _, server := range cert.OCSPServer { - ocspResponse, err := ta.ocspQuery(server, cert, issuer) - if err != nil { - ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) - continue + // starting from the root CA and moving towards the leaf certificate, + // check for revocation of intermediates too + for i := len(chain) - 1; i > 0; i-- { + cert := chain[i-1] + issuer := chain[i] + + revokedByOCSP, checkedByOCSP := ta.ocspChecker.isRevokedBy(cert, issuer) + couldCheck = couldCheck && checkedByOCSP + + if revokedByOCSP && checkedByOCSP { + return errors.New("certificate revoked by OCSP"), couldCheck } - switch ocspResponse.Status { - case ocsp.Good: - return false, nil - case ocsp.Revoked: - return true, fmt.Errorf("client certificate is revoked by server %s", server) - case ocsp.Unknown: - log.Debugf("unknow OCSP status for server %s", server) - continue + revokedByCRL, checkedByCRL := ta.crlChecker.isRevokedBy(cert, issuer) + couldCheck = couldCheck && checkedByCRL + + if revokedByCRL && checkedByCRL { + return errors.New("certificate revoked by CRL"), couldCheck } } - log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") - - return true, nil + return nil, couldCheck } -func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { - if ta.CrlPath == "" { - ta.logger.Warn("no crl_path, skipping CRL check") - return false, nil - } +func (ta *TLSAuth) setAllowedOu(allowedOus []string) error { + uniqueOUs := make(map[string]struct{}) - crlContent, err := os.ReadFile(ta.CrlPath) - if err != nil { - ta.logger.Warnf("could not read CRL file, skipping check: %s", err) - return false, nil - } + for _, ou := range allowedOus { + // disallow empty ou + if ou == "" { + return errors.New("allowed_ou configuration contains invalid empty string") + } - crl, err := x509.ParseCRL(crlContent) - if err != nil { - ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) - return false, nil - } + if _, exists := uniqueOUs[ou]; exists { + ta.logger.Warningf("dropping duplicate ou %s", ou) + continue + } - if crl.HasExpired(time.Now().UTC()) { - ta.logger.Warn("CRL has expired, will still validate the cert against it.") - } + uniqueOUs[ou] = struct{}{} - for _, revoked := range crl.TBSCertList.RevokedCertificates { - if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { - return true, fmt.Errorf("client certificate is revoked by CRL") - } + ta.AllowedOUs = append(ta.AllowedOUs, ou) } - return false, nil + return nil } -func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - sn := cert.SerialNumber.String() - if cacheValue, ok := ta.revokationCache[sn]; ok { - if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration { - ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err) - return cacheValue.revoked, cacheValue.err - } else { - ta.logger.Debugf("TLSAuth: cached value expired, removing from cache") - delete(ta.revokationCache, sn) +func (ta *TLSAuth) checkAllowedOU(ous []string) error { + for _, ou := range ous { + if slices.Contains(ta.AllowedOUs, ou) { + return nil } - } else { - ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) } - revoked, err := ta.isOCSPRevoked(cert, issuer) - if err != nil { - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), - } - - return true, err - } + return fmt.Errorf("client certificate OU %v doesn't match expected OU %v", ous, ta.AllowedOUs) +} - if revoked { - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), - } +func (ta *TLSAuth) ValidateCert(c *gin.Context) (string, error) { + // Checks cert validity, Returns true + CN if client cert matches requested OU + var leaf *x509.Certificate - return true, nil + if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { + return "", errors.New("no certificate in request") } - revoked, err = ta.isCRLRevoked(cert) - ta.revokationCache[sn] = cacheEntry{ - revoked: revoked, - err: err, - timestamp: time.Now().UTC(), + if len(c.Request.TLS.VerifiedChains) == 0 { + return "", errors.New("no verified cert in request") } - return revoked, err -} + // although there can be multiple chains, the leaf certificate is the same + // we take the first one + leaf = c.Request.TLS.VerifiedChains[0][0] -func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { - if ta.isExpired(cert) { - return true, nil + if err := ta.checkAllowedOU(leaf.Subject.OrganizationalUnit); err != nil { + return "", err } - revoked, err := ta.isRevoked(cert, issuer) - if err != nil { - //Fail securely, if we can't check the revocation status, let's consider the cert invalid - //We may change this in the future based on users feedback, but this seems the most sensible thing to do - return true, fmt.Errorf("could not check for client certification revocation status: %w", err) + if ta.isExpired(leaf) { + return "", errors.New("client certificate is expired") } - return revoked, nil -} - -func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error { - for _, ou := range allowedOus { - //disallow empty ou - if ou == "" { - return fmt.Errorf("empty ou isn't allowed") + if validErr, cached := ta.revocationCache.Get(leaf); cached { + if validErr != nil { + return "", fmt.Errorf("(cache) %w", validErr) } - //drop & warn on duplicate ou - ok := true - - for _, validOu := range ta.AllowedOUs { - if validOu == ou { - ta.logger.Warningf("dropping duplicate ou %s", ou) - - ok = false - } - } - - if ok { - ta.AllowedOUs = append(ta.AllowedOUs, ou) - } + return leaf.Subject.CommonName, nil } - return nil -} + okToCache := true -func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { - //Checks cert validity, Returns true + CN if client cert matches requested OU - var clientCert *x509.Certificate + var validErr error - if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { - //do not error if it's not TLS or there are no peer certs - return false, "", nil - } + var couldCheck bool - if len(c.Request.TLS.VerifiedChains) > 0 { - validOU := false - clientCert = c.Request.TLS.VerifiedChains[0][0] - - for _, ou := range clientCert.Subject.OrganizationalUnit { - for _, allowedOu := range ta.AllowedOUs { - if allowedOu == ou { - validOU = true - break - } - } - } - - if !validOU { - return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", - clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) - } - - revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) - if err != nil { - ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) - return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) - } + for _, chain := range c.Request.TLS.VerifiedChains { + validErr, couldCheck = ta.checkRevocationPath(chain) + okToCache = okToCache && couldCheck - if revoked { - return false, "", fmt.Errorf("client certificate is revoked") + if validErr != nil { + break } + } - ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) + if okToCache { + ta.revocationCache.Set(leaf, validErr) + } - return true, clientCert.Subject.CommonName, nil + if validErr != nil { + return "", validErr } - return false, "", fmt.Errorf("no verified cert in request") + return leaf.Subject.CommonName, nil } func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { + var err error + + cache := NewRevocationCache(cacheExpiration, logger) + ta := &TLSAuth{ - revokationCache: map[string]cacheEntry{}, - cacheExpiration: cacheExpiration, - CrlPath: crlPath, + revocationCache: cache, + ocspChecker: NewOCSPChecker(logger), logger: logger, } - err := ta.SetAllowedOu(allowedOus) - if err != nil { + switch crlPath { + case "": + logger.Info("no crl_path, skipping CRL checks") + default: + ta.crlChecker, err = NewCRLChecker(crlPath, cache.Empty, logger) + if err != nil { + return nil, err + } + } + + if err := ta.setAllowedOu(allowedOus); err != nil { return nil, err } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index a3996850a2b..7dd6b346aa9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -3,6 +3,7 @@ package apiserver import ( "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -21,21 +22,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - SyncInterval = time.Second * 10 -) +var SyncInterval = time.Second * 10 -const ( - PapiPullKey = "papi:last_pull" -) +const PapiPullKey = "papi:last_pull" -var ( - operationMap = map[string]func(*Message, *Papi, bool) error{ - "decision": DecisionCmd, - "alert": AlertCmd, - "management": ManagementCmd, - } -) +var operationMap = map[string]func(*Message, *Papi, bool) error{ + "decision": DecisionCmd, + "alert": AlertCmd, + "management": ManagementCmd, +} type Header struct { OperationType string `json:"operation_type"` @@ -87,21 +82,21 @@ type PapiPermCheckSuccess struct { } func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) { - logger := log.New() if err := types.ConfigureLogger(logger); err != nil { - return &Papi{}, fmt.Errorf("creating papi logger: %s", err) + return &Papi{}, fmt.Errorf("creating papi logger: %w", err) } + logger.SetLevel(logLevel) papiUrl := *apic.apiClient.PapiURL papiUrl.Path = fmt.Sprintf("%s%s", types.PAPIVersion, types.PAPIPollUrl) + longPollClient, err := longpollclient.NewLongPollClient(longpollclient.LongPollClientConfig{ Url: papiUrl, Logger: logger, HttpClient: apic.apiClient.GetClient(), }) - if err != nil { return &Papi{}, fmt.Errorf("failed to create PAPI client: %w", err) } @@ -132,55 +127,69 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger := p.Logger.WithField("request-id", event.RequestId) logger.Debugf("message received: %+v", event.Data) + message := &Message{} if err := json.Unmarshal([]byte(event.Data), message); err != nil { - return fmt.Errorf("polling papi message format is not compatible: %+v: %s", event.Data, err) + return fmt.Errorf("polling papi message format is not compatible: %+v: %w", event.Data, err) } + if message.Header == nil { - return fmt.Errorf("no header in message, skipping") + return errors.New("no header in message, skipping") } + if message.Header.Source == nil { - return fmt.Errorf("no source user in header message, skipping") + return errors.New("no source user in header message, skipping") } - if operationFunc, ok := operationMap[message.Header.OperationType]; ok { - logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p, sync) - if err != nil { - return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err) - } - } else { + operationFunc, ok := operationMap[message.Header.OperationType] + if !ok { return fmt.Errorf("operation '%s' unknown, continue", message.Header.OperationType) } + + logger.Debugf("Calling operation '%s'", message.Header.OperationType) + + err := operationFunc(message, p, sync) + if err != nil { + return fmt.Errorf("'%s %s failed: %w", message.Header.OperationType, message.Header.OperationCmd, err) + } + return nil } -func (p *Papi) GetPermissions() (PapiPermCheckSuccess, error) { +func (p *Papi) GetPermissions(ctx context.Context) (PapiPermCheckSuccess, error) { httpClient := p.apiClient.GetClient() papiCheckUrl := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) - req, err := http.NewRequest(http.MethodGet, papiCheckUrl, nil) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, papiCheckUrl, nil) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request: %w", err) } + resp, err := httpClient.Do(req) if err != nil { - log.Fatalf("failed to get response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to get response: %w", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { errResp := PapiPermCheckError{} + err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return PapiPermCheckSuccess{}, fmt.Errorf("unable to query PAPI : %s (%d)", errResp.Error, resp.StatusCode) } + respBody := PapiPermCheckSuccess{} + err = json.NewDecoder(resp.Body).Decode(&respBody) if err != nil { - return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response : %s", err) + return PapiPermCheckSuccess{}, fmt.Errorf("failed to decode response: %w", err) } + return respBody, nil } @@ -202,7 +211,7 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { return err } - reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order + reversedEvents := reverse(events) // PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order eventsCount := len(events) p.Logger.Infof("received %d events", eventsCount) @@ -215,38 +224,38 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { } p.Logger.Debugf("finished handling events") - //Don't update the timestamp in DB, as a "real" LAPI might be running - //Worst case, crowdsec will receive a few duplicated events and will discard them + // Don't update the timestamp in DB, as a "real" LAPI might be running + // Worst case, crowdsec will receive a few duplicated events and will discard them return nil } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } - //value doesn't exist, it's first time we're pulling + // value doesn't exist, it's first time we're pulling if lastTimestampStr == nil { binTime, err := lastTimestamp.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) } } else { if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil { - return fmt.Errorf("failed to unmarshal last timestamp: %w", err) + return fmt.Errorf("failed to parse last timestamp: %w", err) } } @@ -254,12 +263,12 @@ func (p *Papi) Pull() error { for event := range p.Client.Start(lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) - //update last timestamp in database + // update last timestamp in database newTime := time.Now().UTC() binTime, err := newTime.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } err = p.handleEvent(event, false) @@ -268,7 +277,7 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } @@ -329,7 +338,7 @@ func (p *Papi) SyncDecisions() error { func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { var ( cache []models.DecisionsDeleteRequestItem = *cacheOrig - send models.DecisionsDeleteRequest + send models.DecisionsDeleteRequest ) bulkSize := 50 @@ -359,7 +368,7 @@ func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { - //we log it here as well, because the return value of func might be discarded + // we log it here as well, because the return value of func might be discarded p.Logger.Errorf("sending deleted decisions to central API: %s", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index ba02034882c..78f5dc9b0fe 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -37,7 +38,13 @@ type forcePull struct { Blocklist *blocklistLink `json:"blocklist,omitempty"` } +type listUnsubscribe struct { + Name string `json:"name"` +} + func DecisionCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "delete": data, err := json.Marshal(message.Data) @@ -59,10 +66,10 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { - return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err) + return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } decisions := make([]*models.Decision, 0) @@ -90,6 +97,8 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { } func AlertCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -126,12 +135,13 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { alert.Scenario = ptr.Of("") alert.Source = &models.Source{} - //if we're setting Source.Scope to types.ConsoleOrigin, it messes up the alert's value + // if we're setting Source.Scope to types.ConsoleOrigin, it messes up the alert's value if len(alert.Decisions) >= 1 { alert.Source.Scope = alert.Decisions[0].Scope alert.Source.Value = alert.Decisions[0].Value } else { log.Warningf("No decision found in alert for Polling API (%s : %s)", message.Header.Source.User, message.Header.Message) + alert.Source.Scope = ptr.Of(types.ConsoleOrigin) alert.Source.Value = &message.Header.Source.User } @@ -146,8 +156,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID) } - //use a different method : alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + // use a different method: alert and/or decision might already be partially present in the database + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { @@ -162,34 +172,69 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } func ManagementCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + if sync { - log.Infof("Ignoring management command from PAPI in sync mode") + p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil } switch message.Header.OperationCmd { + case "blocklist_unsubscribe": + data, err := json.Marshal(message.Data) + if err != nil { + return err + } + + unsubscribeMsg := listUnsubscribe{} + if err := json.Unmarshal(data, &unsubscribeMsg); err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + if unsubscribeMsg.Name == "" { + return fmt.Errorf("message for '%s' contains bad data format: missing blocklist name", message.Header.OperationType) + } + + p.Logger.Infof("Received blocklist_unsubscribe command from PAPI, unsubscribing from blocklist %s", unsubscribeMsg.Name) + + filter := make(map[string][]string) + filter["origin"] = []string{types.ListOrigin} + filter["scenario"] = []string{unsubscribeMsg.Name} + + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) + if err != nil { + return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) + } + + p.Logger.Infof("deleted %d decisions for list %s", len(deletedDecisions), unsubscribeMsg.Name) case "reauth": - log.Infof("Received reauth command from PAPI, resetting token") + p.Logger.Infof("Received reauth command from PAPI, resetting token") p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken() case "force_pull": data, err := json.Marshal(message.Data) if err != nil { return err } + forcePullMsg := forcePull{} + if err := json.Unmarshal(data, &forcePullMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } + ctx := context.TODO() + if forcePullMsg.Blocklist == nil { - log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) + p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") + + err = p.apic.PullTop(ctx, true) if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + return fmt.Errorf("failed to force pull operation: %w", err) } } else { - log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) - err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) + + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ Name: &forcePullMsg.Blocklist.Name, URL: &forcePullMsg.Blocklist.Url, Remediation: &forcePullMsg.Blocklist.Remediation, diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go new file mode 100644 index 00000000000..32aeb7d9a5a --- /dev/null +++ b/pkg/apiserver/usage_metrics_test.go @@ -0,0 +1,388 @@ +package apiserver + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +func TestLPMetrics(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + body string + expectedStatusCode int + expectedResponse string + expectedMetricsCount int + expectedOSName string + expectedOSVersion string + expectedFeatureFlags string + authType string + }{ + { + name: "empty metrics for LP", + body: `{ + }`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing log processor data", + authType: PASSWORD, + }, + { + name: "basic metrics with empty dynamic metrics for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "basic metrics with dynamic metrics for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [{"meta":{"utc_now_timestamp":42, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }, {"meta":{"utc_now_timestamp":43, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "wrong auth type for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing remediation component data", + authType: APIKEY, + }, + { + name: "missing OS field for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedResponse: "", + expectedMetricsCount: 1, + expectedFeatureFlags: "a,b,c", + authType: PASSWORD, + }, + { + name: "missing datasources for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "log_processors.0.datasources in body is required", + authType: PASSWORD, + }, + { + name: "missing feature flags for LP", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedOSName: "foo", + expectedOSVersion: "42", + authType: PASSWORD, + }, + { + name: "missing OS name", + body: ` +{ + "log_processors": [ + { + "version": "1.42", + "os": {"version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"], + "datasources": {"file": 42}, + "hub_items": {} + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "log_processors.0.os.name in body is required", + authType: PASSWORD, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lapi := SetupLAPITest(t, ctx) + + dbClient, err := database.NewClient(ctx, lapi.DBConfig) + if err != nil { + t.Fatalf("unable to create database client: %s", err) + } + + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + + assert.Equal(t, tt.expectedStatusCode, w.Code) + assert.Contains(t, w.Body.String(), tt.expectedResponse) + + machine, _ := dbClient.QueryMachineByID(ctx, "test") + metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") + + assert.Len(t, metrics, tt.expectedMetricsCount) + assert.Equal(t, tt.expectedOSName, machine.Osname) + assert.Equal(t, tt.expectedOSVersion, machine.Osversion) + assert.Equal(t, tt.expectedFeatureFlags, machine.Featureflags) + + if len(metrics) > 0 { + assert.Equal(t, "test", metrics[0].GeneratedBy) + assert.Equal(t, metric.GeneratedType("LP"), metrics[0].GeneratedType) + } + }) + } +} + +func TestRCMetrics(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + body string + expectedStatusCode int + expectedResponse string + expectedMetricsCount int + expectedOSName string + expectedOSVersion string + expectedFeatureFlags string + authType string + }{ + { + name: "empty metrics for RC", + body: `{ + }`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing remediation component data", + authType: APIKEY, + }, + { + name: "basic metrics with empty dynamic metrics for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "basic metrics with dynamic metrics for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [{"meta":{"utc_now_timestamp":42, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }, {"meta":{"utc_now_timestamp":43, "window_size_seconds": 42}, "items": [{"name": "foo", "value": 42, "unit": "bla"}] }], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedResponse: "", + expectedOSName: "foo", + expectedOSVersion: "42", + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "wrong auth type for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: "Missing log processor data", + authType: PASSWORD, + }, + { + name: "missing OS field for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedResponse: "", + expectedMetricsCount: 1, + expectedFeatureFlags: "a,b,c", + authType: APIKEY, + }, + { + name: "missing feature flags for RC", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"name":"foo", "version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [] + } + ] +}`, + expectedStatusCode: http.StatusCreated, + expectedMetricsCount: 1, + expectedOSName: "foo", + expectedOSVersion: "42", + authType: APIKEY, + }, + { + name: "missing OS name", + body: ` +{ + "remediation_components": [ + { + "version": "1.42", + "os": {"version": "42"}, + "utc_startup_timestamp": 42, + "metrics": [], + "feature_flags": ["a", "b", "c"] + } + ] +}`, + expectedStatusCode: http.StatusUnprocessableEntity, + expectedResponse: "remediation_components.0.os.name in body is required", + authType: APIKEY, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lapi := SetupLAPITest(t, ctx) + + dbClient, err := database.NewClient(ctx, lapi.DBConfig) + if err != nil { + t.Fatalf("unable to create database client: %s", err) + } + + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + + assert.Equal(t, tt.expectedStatusCode, w.Code) + assert.Contains(t, w.Body.String(), tt.expectedResponse) + + bouncer, _ := dbClient.SelectBouncerByName(ctx, "test") + metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") + + assert.Len(t, metrics, tt.expectedMetricsCount) + assert.Equal(t, tt.expectedOSName, bouncer.Osname) + assert.Equal(t, tt.expectedOSVersion, bouncer.Osversion) + assert.Equal(t, tt.expectedFeatureFlags, bouncer.Featureflags) + + if len(metrics) > 0 { + assert.Equal(t, "test", metrics[0].GeneratedBy) + assert.Equal(t, metric.GeneratedType("RC"), metrics[0].GeneratedType) + } + }) + } +} diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go index ec7e7bef3b6..30784b23db0 100644 --- a/pkg/appsec/appsec.go +++ b/pkg/appsec/appsec.go @@ -1,17 +1,20 @@ package appsec import ( + "errors" "fmt" + "net/http" "os" "regexp" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" - log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" ) type Hook struct { @@ -30,8 +33,13 @@ const ( hookOnMatch ) -func (h *Hook) Build(hookStage int) error { +const ( + BanRemediation = "ban" + CaptchaRemediation = "captcha" + AllowRemediation = "allow" +) +func (h *Hook) Build(hookStage int) error { ctx := map[string]interface{}{} switch hookStage { case hookOnLoad: @@ -45,7 +53,7 @@ func (h *Hook) Build(hookStage int) error { } opts := exprhelpers.GetExprOptions(ctx) if h.Filter != "" { - program, err := expr.Compile(h.Filter, opts...) //FIXME: opts + program, err := expr.Compile(h.Filter, opts...) // FIXME: opts if err != nil { return fmt.Errorf("unable to compile filter %s : %w", h.Filter, err) } @@ -62,12 +70,13 @@ func (h *Hook) Build(hookStage int) error { } type AppsecTempResponse struct { - InBandInterrupt bool - OutOfBandInterrupt bool - Action string //allow, deny, captcha, log - HTTPResponseCode int - SendEvent bool //do we send an internal event on rule match - SendAlert bool //do we send an alert on rule match + InBandInterrupt bool + OutOfBandInterrupt bool + Action string // allow, deny, captcha, log + UserHTTPResponseCode int // The response code to send to the user + BouncerHTTPResponseCode int // The response code to send to the remediation component + SendEvent bool // do we send an internal event on rule match + SendAlert bool // do we send an alert on rule match } type AppsecSubEngineOpts struct { @@ -83,7 +92,7 @@ type AppsecRuntimeConfig struct { InBandRules []AppsecCollection DefaultRemediation string - RemediationByTag map[string]string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + RemediationByTag map[string]string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME RemediationById map[int]string CompiledOnLoad []Hook CompiledPreEval []Hook @@ -91,56 +100,57 @@ type AppsecRuntimeConfig struct { CompiledOnMatch []Hook CompiledVariablesTracking []*regexp.Regexp Config *AppsecConfig - //CorazaLogger debuglog.Logger + // CorazaLogger debuglog.Logger - //those are ephemeral, created/destroyed with every req - OutOfBandTx ExtendedTransaction //is it a good idea ? - InBandTx ExtendedTransaction //is it a good idea ? + // those are ephemeral, created/destroyed with every req + OutOfBandTx ExtendedTransaction // is it a good idea ? + InBandTx ExtendedTransaction // is it a good idea ? Response AppsecTempResponse - //should we store matched rules here ? + // should we store matched rules here ? Logger *log.Entry - //Set by on_load to ignore some rules on loading + // Set by on_load to ignore some rules on loading DisabledInBandRuleIds []int - DisabledInBandRulesTags []string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + DisabledInBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME DisabledOutOfBandRuleIds []int - DisabledOutOfBandRulesTags []string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + DisabledOutOfBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME } type AppsecConfig struct { - Name string `yaml:"name"` - OutOfBandRules []string `yaml:"outofband_rules"` - InBandRules []string `yaml:"inband_rules"` - DefaultRemediation string `yaml:"default_remediation"` - DefaultPassAction string `yaml:"default_pass_action"` - BlockedHTTPCode int `yaml:"blocked_http_code"` - PassedHTTPCode int `yaml:"passed_http_code"` - OnLoad []Hook `yaml:"on_load"` - PreEval []Hook `yaml:"pre_eval"` - PostEval []Hook `yaml:"post_eval"` - OnMatch []Hook `yaml:"on_match"` - VariablesTracking []string `yaml:"variables_tracking"` - InbandOptions AppsecSubEngineOpts `yaml:"inband_options"` - OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"` + Name string `yaml:"name"` + OutOfBandRules []string `yaml:"outofband_rules"` + InBandRules []string `yaml:"inband_rules"` + DefaultRemediation string `yaml:"default_remediation"` + DefaultPassAction string `yaml:"default_pass_action"` + BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` // returned to the bouncer + BouncerPassedHTTPCode int `yaml:"passed_http_code"` // returned to the bouncer + UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` // returned to the user + UserPassedHTTPCode int `yaml:"user_passed_http_code"` // returned to the user + + OnLoad []Hook `yaml:"on_load"` + PreEval []Hook `yaml:"pre_eval"` + PostEval []Hook `yaml:"post_eval"` + OnMatch []Hook `yaml:"on_match"` + VariablesTracking []string `yaml:"variables_tracking"` + InbandOptions AppsecSubEngineOpts `yaml:"inband_options"` + OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"` LogLevel *log.Level `yaml:"log_level"` Logger *log.Entry `yaml:"-"` } func (w *AppsecRuntimeConfig) ClearResponse() { - w.Logger.Debugf("#-> %p", w) w.Response = AppsecTempResponse{} - w.Logger.Debugf("-> %p", w.Config) w.Response.Action = w.Config.DefaultPassAction - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode + w.Response.BouncerHTTPResponseCode = w.Config.BouncerPassedHTTPCode + w.Response.UserHTTPResponseCode = w.Config.UserPassedHTTPCode w.Response.SendEvent = true w.Response.SendAlert = true } func (wc *AppsecConfig) LoadByPath(file string) error { - wc.Logger.Debugf("loading config %s", file) yamlFile, err := os.ReadFile(file) @@ -153,7 +163,7 @@ func (wc *AppsecConfig) LoadByPath(file string) error { } if wc.Name == "" { - return fmt.Errorf("name cannot be empty") + return errors.New("name cannot be empty") } if wc.LogLevel == nil { lvl := wc.Logger.Logger.GetLevel() @@ -165,19 +175,13 @@ func (wc *AppsecConfig) LoadByPath(file string) error { } func (wc *AppsecConfig) Load(configName string) error { - appsecConfigs := hub.GetItemMap(cwhub.APPSEC_CONFIGS) + item := hub.GetItem(cwhub.APPSEC_CONFIGS, configName) - for _, hubAppsecConfigItem := range appsecConfigs { - if !hubAppsecConfigItem.State.Installed { - continue - } - if hubAppsecConfigItem.Name != configName { - continue - } - wc.Logger.Infof("loading %s", hubAppsecConfigItem.State.LocalPath) - err := wc.LoadByPath(hubAppsecConfigItem.State.LocalPath) + if item != nil && item.State.Installed { + wc.Logger.Infof("loading %s", item.State.LocalPath) + err := wc.LoadByPath(item.State.LocalPath) if err != nil { - return fmt.Errorf("unable to load appsec-config %s : %s", hubAppsecConfigItem.State.LocalPath, err) + return fmt.Errorf("unable to load appsec-config %s : %s", item.State.LocalPath, err) } return nil } @@ -191,30 +195,41 @@ func (wc *AppsecConfig) GetDataDir() string { func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { ret := &AppsecRuntimeConfig{Logger: wc.Logger.WithField("component", "appsec_runtime_config")} - //set the defaults - switch wc.DefaultRemediation { - case "": - wc.DefaultRemediation = "ban" - case "ban", "captcha", "log": - //those are the officially supported remediation(s) - default: - wc.Logger.Warningf("default '%s' remediation of %s is none of [ban,captcha,log] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name) + + if wc.BouncerBlockedHTTPCode == 0 { + wc.BouncerBlockedHTTPCode = http.StatusForbidden } - if wc.BlockedHTTPCode == 0 { - wc.BlockedHTTPCode = 403 + if wc.BouncerPassedHTTPCode == 0 { + wc.BouncerPassedHTTPCode = http.StatusOK } - if wc.PassedHTTPCode == 0 { - wc.PassedHTTPCode = 200 + + if wc.UserBlockedHTTPCode == 0 { + wc.UserBlockedHTTPCode = http.StatusForbidden + } + if wc.UserPassedHTTPCode == 0 { + wc.UserPassedHTTPCode = http.StatusOK } if wc.DefaultPassAction == "" { - wc.DefaultPassAction = "allow" + wc.DefaultPassAction = AllowRemediation + } + if wc.DefaultRemediation == "" { + wc.DefaultRemediation = BanRemediation + } + + // set the defaults + switch wc.DefaultRemediation { + case BanRemediation, CaptchaRemediation, AllowRemediation: + // those are the officially supported remediation(s) + default: + wc.Logger.Warningf("default '%s' remediation of %s is none of [%s,%s,%s] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name, BanRemediation, CaptchaRemediation, AllowRemediation) } + ret.Name = wc.Name ret.Config = wc ret.DefaultRemediation = wc.DefaultRemediation wc.Logger.Tracef("Loading config %+v", wc) - //load rules + // load rules for _, rule := range wc.OutOfBandRules { wc.Logger.Infof("loading outofband rule %s", rule) collections, err := LoadCollection(rule, wc.Logger.WithField("component", "appsec_collection_loader")) @@ -236,8 +251,11 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { wc.Logger.Infof("Loaded %d inband rules", len(ret.InBandRules)) - //load hooks + // load hooks for _, hook := range wc.OnLoad { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for on_load hook : %s", hook.OnSuccess) + } err := hook.Build(hookOnLoad) if err != nil { return nil, fmt.Errorf("unable to build on_load hook : %s", err) @@ -246,6 +264,9 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { } for _, hook := range wc.PreEval { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for pre_eval hook : %s", hook.OnSuccess) + } err := hook.Build(hookPreEval) if err != nil { return nil, fmt.Errorf("unable to build pre_eval hook : %s", err) @@ -254,6 +275,9 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { } for _, hook := range wc.PostEval { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for post_eval hook : %s", hook.OnSuccess) + } err := hook.Build(hookPostEval) if err != nil { return nil, fmt.Errorf("unable to build post_eval hook : %s", err) @@ -262,6 +286,9 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { } for _, hook := range wc.OnMatch { + if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for on_match hook : %s", hook.OnSuccess) + } err := hook.Build(hookOnMatch) if err != nil { return nil, fmt.Errorf("unable to build on_match hook : %s", err) @@ -269,7 +296,7 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { ret.CompiledOnMatch = append(ret.CompiledOnMatch, hook) } - //variable tracking + // variable tracking for _, variable := range wc.VariablesTracking { compiledVariableRule, err := regexp.Compile(variable) if err != nil { @@ -281,6 +308,7 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { } func (w *AppsecRuntimeConfig) ProcessOnLoadRules() error { + has_match := false for _, rule := range w.CompiledOnLoad { if rule.FilterExpr != nil { output, err := exprhelpers.Run(rule.FilterExpr, GetOnLoadEnv(w), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -297,6 +325,7 @@ func (w *AppsecRuntimeConfig) ProcessOnLoadRules() error { w.Logger.Errorf("Filter must return a boolean, can't filter") continue } + has_match = true } for _, applyExpr := range rule.ApplyExpr { o, err := exprhelpers.Run(applyExpr, GetOnLoadEnv(w), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -311,12 +340,15 @@ func (w *AppsecRuntimeConfig) ProcessOnLoadRules() error { default: } } + if has_match && rule.OnSuccess == "break" { + break + } } return nil } func (w *AppsecRuntimeConfig) ProcessOnMatchRules(request *ParsedRequest, evt types.Event) error { - + has_match := false for _, rule := range w.CompiledOnMatch { if rule.FilterExpr != nil { output, err := exprhelpers.Run(rule.FilterExpr, GetOnMatchEnv(w, request, evt), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -333,6 +365,7 @@ func (w *AppsecRuntimeConfig) ProcessOnMatchRules(request *ParsedRequest, evt ty w.Logger.Errorf("Filter must return a boolean, can't filter") continue } + has_match = true } for _, applyExpr := range rule.ApplyExpr { o, err := exprhelpers.Run(applyExpr, GetOnMatchEnv(w, request, evt), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -347,12 +380,15 @@ func (w *AppsecRuntimeConfig) ProcessOnMatchRules(request *ParsedRequest, evt ty default: } } + if has_match && rule.OnSuccess == "break" { + break + } } return nil } func (w *AppsecRuntimeConfig) ProcessPreEvalRules(request *ParsedRequest) error { - w.Logger.Debugf("processing %d pre_eval rules", len(w.CompiledPreEval)) + has_match := false for _, rule := range w.CompiledPreEval { if rule.FilterExpr != nil { output, err := exprhelpers.Run(rule.FilterExpr, GetPreEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -369,6 +405,7 @@ func (w *AppsecRuntimeConfig) ProcessPreEvalRules(request *ParsedRequest) error w.Logger.Errorf("Filter must return a boolean, can't filter") continue } + has_match = true } // here means there is no filter or the filter matched for _, applyExpr := range rule.ApplyExpr { @@ -384,12 +421,16 @@ func (w *AppsecRuntimeConfig) ProcessPreEvalRules(request *ParsedRequest) error default: } } + if has_match && rule.OnSuccess == "break" { + break + } } return nil } func (w *AppsecRuntimeConfig) ProcessPostEvalRules(request *ParsedRequest) error { + has_match := false for _, rule := range w.CompiledPostEval { if rule.FilterExpr != nil { output, err := exprhelpers.Run(rule.FilterExpr, GetPostEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) @@ -406,11 +447,11 @@ func (w *AppsecRuntimeConfig) ProcessPostEvalRules(request *ParsedRequest) error w.Logger.Errorf("Filter must return a boolean, can't filter") continue } + has_match = true } // here means there is no filter or the filter matched for _, applyExpr := range rule.ApplyExpr { o, err := exprhelpers.Run(applyExpr, GetPostEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) - if err != nil { w.Logger.Errorf("unable to apply appsec post_eval expr: %s", err) continue @@ -423,6 +464,9 @@ func (w *AppsecRuntimeConfig) ProcessPostEvalRules(request *ParsedRequest) error default: } } + if has_match && rule.OnSuccess == "break" { + break + } } return nil @@ -551,29 +595,15 @@ func (w *AppsecRuntimeConfig) SetActionByName(name string, action string) error } func (w *AppsecRuntimeConfig) SetAction(action string) error { - //log.Infof("setting to %s", action) + // log.Infof("setting to %s", action) w.Logger.Debugf("setting action to %s", action) - switch action { - case "allow": - w.Response.Action = action - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode - //@tko how should we handle this ? it seems bouncer only understand bans, but it might be misleading ? - case "deny", "ban", "block": - w.Response.Action = "ban" - case "log": - w.Response.Action = action - w.Response.HTTPResponseCode = w.Config.PassedHTTPCode - case "captcha": - w.Response.Action = action - default: - w.Response.Action = action - } + w.Response.Action = action return nil } func (w *AppsecRuntimeConfig) SetHTTPCode(code int) error { w.Logger.Debugf("setting http code to %d", code) - w.Response.HTTPResponseCode = code + w.Response.UserHTTPResponseCode = code return nil } @@ -582,24 +612,23 @@ type BodyResponse struct { HTTPStatus int `json:"http_status"` } -func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) BodyResponse { - resp := BodyResponse{} - //if there is no interrupt, we should allow with default code - if !response.InBandInterrupt { - resp.Action = w.Config.DefaultPassAction - resp.HTTPStatus = w.Config.PassedHTTPCode - return resp - } - resp.Action = response.Action - if resp.Action == "" { - resp.Action = w.Config.DefaultRemediation - } - logger.Debugf("action is %s", resp.Action) +func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) (int, BodyResponse) { + var bouncerStatusCode int - resp.HTTPStatus = response.HTTPResponseCode - if resp.HTTPStatus == 0 { - resp.HTTPStatus = w.Config.BlockedHTTPCode + resp := BodyResponse{Action: response.Action} + if response.Action == AllowRemediation { + resp.HTTPStatus = w.Config.UserPassedHTTPCode + bouncerStatusCode = w.Config.BouncerPassedHTTPCode + } else { // ban, captcha and anything else + resp.HTTPStatus = response.UserHTTPResponseCode + if resp.HTTPStatus == 0 { + resp.HTTPStatus = w.Config.UserBlockedHTTPCode + } + bouncerStatusCode = response.BouncerHTTPResponseCode + if bouncerStatusCode == 0 { + bouncerStatusCode = w.Config.BouncerBlockedHTTPCode + } } - logger.Debugf("http status is %d", resp.HTTPStatus) - return resp + + return bouncerStatusCode, resp } diff --git a/pkg/appsec/appsec_rule/appsec_rule.go b/pkg/appsec/appsec_rule/appsec_rule.go index 289405ef161..136d8b11cb7 100644 --- a/pkg/appsec/appsec_rule/appsec_rule.go +++ b/pkg/appsec/appsec_rule/appsec_rule.go @@ -1,6 +1,7 @@ package appsec_rule import ( + "errors" "fmt" ) @@ -48,15 +49,15 @@ type CustomRule struct { func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) { if v.Zones == nil && v.And == nil && v.Or == nil { - return "", nil, fmt.Errorf("no zones defined") + return "", nil, errors.New("no zones defined") } if v.Match.Type == "" && v.And == nil && v.Or == nil { - return "", nil, fmt.Errorf("no match type defined") + return "", nil, errors.New("no match type defined") } if v.Match.Value == "" && v.And == nil && v.Or == nil { - return "", nil, fmt.Errorf("no match value defined") + return "", nil, errors.New("no match value defined") } switch ruleType { diff --git a/pkg/appsec/appsec_rule/modsecurity.go b/pkg/appsec/appsec_rule/modsecurity.go index 0b117cd773d..135ba525e8e 100644 --- a/pkg/appsec/appsec_rule/modsecurity.go +++ b/pkg/appsec/appsec_rule/modsecurity.go @@ -1,6 +1,7 @@ package appsec_rule import ( + "errors" "fmt" "hash/fnv" "strings" @@ -10,29 +11,41 @@ type ModsecurityRule struct { ids []uint32 } -var zonesMap map[string]string = map[string]string{ - "ARGS": "ARGS_GET", - "ARGS_NAMES": "ARGS_GET_NAMES", - "BODY_ARGS": "ARGS_POST", - "BODY_ARGS_NAMES": "ARGS_POST_NAMES", - "HEADERS_NAMES": "REQUEST_HEADERS_NAMES", - "HEADERS": "REQUEST_HEADERS", - "METHOD": "REQUEST_METHOD", - "PROTOCOL": "REQUEST_PROTOCOL", - "URI": "REQUEST_URI", - "RAW_BODY": "REQUEST_BODY", - "FILENAMES": "FILES", +var zonesMap = map[string]string{ + "ARGS": "ARGS_GET", + "ARGS_NAMES": "ARGS_GET_NAMES", + "BODY_ARGS": "ARGS_POST", + "BODY_ARGS_NAMES": "ARGS_POST_NAMES", + "COOKIES": "REQUEST_COOKIES", + "COOKIES_NAMES": "REQUEST_COOKIES_NAMES", + "FILES": "FILES", + "FILES_NAMES": "FILES_NAMES", + "FILES_TOTAL_SIZE": "FILES_COMBINED_SIZE", + "HEADERS_NAMES": "REQUEST_HEADERS_NAMES", + "HEADERS": "REQUEST_HEADERS", + "METHOD": "REQUEST_METHOD", + "PROTOCOL": "REQUEST_PROTOCOL", + "URI": "REQUEST_FILENAME", + "URI_FULL": "REQUEST_URI", + "RAW_BODY": "REQUEST_BODY", + "FILENAMES": "FILES", } -var transformMap map[string]string = map[string]string{ +var transformMap = map[string]string{ "lowercase": "t:lowercase", "uppercase": "t:uppercase", "b64decode": "t:base64Decode", - "hexdecode": "t:hexDecode", - "length": "t:length", + //"hexdecode": "t:hexDecode", -> not supported by coraza + "length": "t:length", + "urldecode": "t:urlDecode", + "trim": "t:trim", + "normalize_path": "t:normalizePath", + "normalizepath": "t:normalizePath", + "htmlentitydecode": "t:htmlEntityDecode", + "html_entity_decode": "t:htmlEntityDecode", } -var matchMap map[string]string = map[string]string{ +var matchMap = map[string]string{ "regex": "@rx", "equals": "@streq", "startsWith": "@beginsWith", @@ -47,7 +60,7 @@ var matchMap map[string]string = map[string]string{ "eq": "@eq", } -var bodyTypeMatch map[string]string = map[string]string{ +var bodyTypeMatch = map[string]string{ "json": "JSON", "xml": "XML", "multipart": "MULTIPART", @@ -55,9 +68,7 @@ var bodyTypeMatch map[string]string = map[string]string{ } func (m *ModsecurityRule) Build(rule *CustomRule, appsecRuleName string) (string, []uint32, error) { - rules, err := m.buildRules(rule, appsecRuleName, false, 0, 0) - if err != nil { return "", nil, err } @@ -87,7 +98,7 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an ret := make([]string, 0) if len(rule.And) != 0 && len(rule.Or) != 0 { - return nil, fmt.Errorf("cannot have both 'and' and 'or' in the same rule") + return nil, errors.New("cannot have both 'and' and 'or' in the same rule") } if rule.And != nil { @@ -154,15 +165,15 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an r.WriteByte(' ') if rule.Match.Type != "" { - if match, ok := matchMap[rule.Match.Type]; ok { - prefix := "" - if rule.Match.Not { - prefix = "!" - } - r.WriteString(fmt.Sprintf(`"%s%s %s"`, prefix, match, rule.Match.Value)) - } else { + match, ok := matchMap[rule.Match.Type] + if !ok { return nil, fmt.Errorf("unknown match type '%s'", rule.Match.Type) } + prefix := "" + if rule.Match.Not { + prefix = "!" + } + r.WriteString(fmt.Sprintf(`"%s%s %s"`, prefix, match, rule.Match.Value)) } //Should phase:2 be configurable? @@ -174,20 +185,20 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an continue } r.WriteByte(',') - if mappedTransform, ok := transformMap[transform]; ok { - r.WriteString(mappedTransform) - } else { + mappedTransform, ok := transformMap[transform] + if !ok { return nil, fmt.Errorf("unknown transform '%s'", transform) } + r.WriteString(mappedTransform) } } if rule.BodyType != "" { - if mappedBodyType, ok := bodyTypeMatch[rule.BodyType]; ok { - r.WriteString(fmt.Sprintf(",ctl:requestBodyProcessor=%s", mappedBodyType)) - } else { + mappedBodyType, ok := bodyTypeMatch[rule.BodyType] + if !ok { return nil, fmt.Errorf("unknown body type '%s'", rule.BodyType) } + r.WriteString(fmt.Sprintf(",ctl:requestBodyProcessor=%s", mappedBodyType)) } if and { diff --git a/pkg/appsec/appsec_rules_collection.go b/pkg/appsec/appsec_rules_collection.go index 2024673c330..d283f95cb19 100644 --- a/pkg/appsec/appsec_rules_collection.go +++ b/pkg/appsec/appsec_rules_collection.go @@ -6,10 +6,10 @@ import ( "path/filepath" "strings" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - - log "github.com/sirupsen/logrus" ) type AppsecCollection struct { @@ -29,11 +29,11 @@ type AppsecCollectionConfig struct { SecLangRules []string `yaml:"seclang_rules"` Rules []appsec_rule.CustomRule `yaml:"rules"` - Labels map[string]interface{} `yaml:"labels"` //Labels is K:V list aiming at providing context the overflow + Labels map[string]interface{} `yaml:"labels"` // Labels is K:V list aiming at providing context the overflow - Data interface{} `yaml:"data"` //Ignore it - hash string `yaml:"-"` - version string `yaml:"-"` + Data interface{} `yaml:"data"` // Ignore it + hash string + version string } type RulesDetails struct { @@ -51,9 +51,7 @@ func LoadCollection(pattern string, logger *log.Entry) ([]AppsecCollection, erro ret := make([]AppsecCollection, 0) for _, appsecRule := range appsecRules { - tmpMatch, err := exprhelpers.Match(pattern, appsecRule.Name) - if err != nil { logger.Errorf("unable to match %s with %s : %s", appsecRule.Name, pattern, err) continue @@ -110,7 +108,7 @@ func LoadCollection(pattern string, logger *log.Entry) ([]AppsecCollection, erro logger.Debugf("Adding rule %s", strRule) appsecCol.Rules = append(appsecCol.Rules, strRule) - //We only take the first id, as it's the one of the "main" rule + // We only take the first id, as it's the one of the "main" rule if _, ok := AppsecRulesDetails[int(rulesId[0])]; !ok { AppsecRulesDetails[int(rulesId[0])] = RulesDetails{ LogLevel: log.InfoLevel, diff --git a/pkg/appsec/coraza_logger.go b/pkg/appsec/coraza_logger.go index 372a0098ecc..d2c1612cbd7 100644 --- a/pkg/appsec/coraza_logger.go +++ b/pkg/appsec/coraza_logger.go @@ -4,11 +4,12 @@ import ( "fmt" "io" - dbg "github.com/crowdsecurity/coraza/v3/debuglog" log "github.com/sirupsen/logrus" + + dbg "github.com/crowdsecurity/coraza/v3/debuglog" ) -var DebugRules map[int]bool = map[int]bool{} +var DebugRules = map[int]bool{} func SetRuleDebug(id int, debug bool) { DebugRules[id] = debug @@ -18,6 +19,7 @@ func GetRuleDebug(id int) bool { if val, ok := DebugRules[id]; ok { return val } + return false } @@ -60,7 +62,9 @@ func (e *crzLogEvent) Str(key, val string) dbg.Event { if e.muted { return e } + e.fields[key] = val + return e } @@ -68,7 +72,9 @@ func (e *crzLogEvent) Err(err error) dbg.Event { if e.muted { return e } + e.fields["error"] = err + return e } @@ -76,22 +82,25 @@ func (e *crzLogEvent) Bool(key string, b bool) dbg.Event { if e.muted { return e } + e.fields[key] = b + return e } func (e *crzLogEvent) Int(key string, i int) dbg.Event { if e.muted { - //this allows us to have per-rule debug logging - if key == "rule_id" && GetRuleDebug(i) { - e.muted = false - e.fields = map[string]interface{}{} - e.level = log.DebugLevel - } else { + if key != "rule_id" || !GetRuleDebug(i) { return e } + // this allows us to have per-rule debug logging + e.muted = false + e.fields = map[string]interface{}{} + e.level = log.DebugLevel } + e.fields[key] = i + return e } @@ -99,7 +108,9 @@ func (e *crzLogEvent) Uint(key string, i uint) dbg.Event { if e.muted { return e } + e.fields[key] = i + return e } @@ -107,7 +118,9 @@ func (e *crzLogEvent) Stringer(key string, val fmt.Stringer) dbg.Event { if e.muted { return e } + e.fields[key] = val + return e } @@ -121,74 +134,84 @@ type crzLogger struct { logLevel log.Level } -func NewCrzLogger(logger *log.Entry) crzLogger { - return crzLogger{logger: logger, logLevel: logger.Logger.GetLevel()} +func NewCrzLogger(logger *log.Entry) *crzLogger { + return &crzLogger{logger: logger, logLevel: logger.Logger.GetLevel()} } -func (c crzLogger) NewMutedEvt(lvl log.Level) dbg.Event { +func (c *crzLogger) NewMutedEvt(lvl log.Level) dbg.Event { return &crzLogEvent{muted: true, logger: c.logger, level: lvl} } -func (c crzLogger) NewEvt(lvl log.Level) dbg.Event { + +func (c *crzLogger) NewEvt(lvl log.Level) dbg.Event { evt := &crzLogEvent{fields: map[string]interface{}{}, logger: c.logger, level: lvl} + if c.defaultFields != nil { for k, v := range c.defaultFields { evt.fields[k] = v } } + return evt } -func (c crzLogger) WithOutput(w io.Writer) dbg.Logger { +func (c *crzLogger) WithOutput(w io.Writer) dbg.Logger { return c } -func (c crzLogger) WithLevel(lvl dbg.Level) dbg.Logger { +func (c *crzLogger) WithLevel(lvl dbg.Level) dbg.Logger { c.logLevel = log.Level(lvl) c.logger.Logger.SetLevel(c.logLevel) + return c } -func (c crzLogger) With(fs ...dbg.ContextField) dbg.Logger { - var e dbg.Event = c.NewEvt(c.logLevel) +func (c *crzLogger) With(fs ...dbg.ContextField) dbg.Logger { + e := c.NewEvt(c.logLevel) for _, f := range fs { e = f(e) } + c.defaultFields = e.(*crzLogEvent).fields + return c } -func (c crzLogger) Trace() dbg.Event { +func (c *crzLogger) Trace() dbg.Event { if c.logLevel < log.TraceLevel { return c.NewMutedEvt(log.TraceLevel) } + return c.NewEvt(log.TraceLevel) } -func (c crzLogger) Debug() dbg.Event { +func (c *crzLogger) Debug() dbg.Event { if c.logLevel < log.DebugLevel { return c.NewMutedEvt(log.DebugLevel) - } + return c.NewEvt(log.DebugLevel) } -func (c crzLogger) Info() dbg.Event { +func (c *crzLogger) Info() dbg.Event { if c.logLevel < log.InfoLevel { return c.NewMutedEvt(log.InfoLevel) } + return c.NewEvt(log.InfoLevel) } -func (c crzLogger) Warn() dbg.Event { +func (c *crzLogger) Warn() dbg.Event { if c.logLevel < log.WarnLevel { return c.NewMutedEvt(log.WarnLevel) } + return c.NewEvt(log.WarnLevel) } -func (c crzLogger) Error() dbg.Event { +func (c *crzLogger) Error() dbg.Event { if c.logLevel < log.ErrorLevel { return c.NewMutedEvt(log.ErrorLevel) } + return c.NewEvt(log.ErrorLevel) } diff --git a/pkg/appsec/loader.go b/pkg/appsec/loader.go index 86c1dc0a80e..c724010cec2 100644 --- a/pkg/appsec/loader.go +++ b/pkg/appsec/loader.go @@ -3,27 +3,22 @@ package appsec import ( "os" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" + + "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -var appsecRules map[string]AppsecCollectionConfig = make(map[string]AppsecCollectionConfig) //FIXME: would probably be better to have a struct for this +var appsecRules = make(map[string]AppsecCollectionConfig) // FIXME: would probably be better to have a struct for this -var hub *cwhub.Hub //FIXME: this is a temporary hack to make the hub available in the package +var hub *cwhub.Hub // FIXME: this is a temporary hack to make the hub available in the package func LoadAppsecRules(hubInstance *cwhub.Hub) error { - hub = hubInstance appsecRules = make(map[string]AppsecCollectionConfig) - for _, hubAppsecRuleItem := range hub.GetItemMap(cwhub.APPSEC_RULES) { - if !hubAppsecRuleItem.State.Installed { - continue - } - + for _, hubAppsecRuleItem := range hub.GetInstalledByType(cwhub.APPSEC_RULES, false) { content, err := os.ReadFile(hubAppsecRuleItem.State.LocalPath) - if err != nil { log.Warnf("unable to read file %s : %s", hubAppsecRuleItem.State.LocalPath, err) continue @@ -32,9 +27,8 @@ func LoadAppsecRules(hubInstance *cwhub.Hub) error { var rule AppsecCollectionConfig err = yaml.UnmarshalStrict(content, &rule) - if err != nil { - log.Warnf("unable to unmarshal file %s : %s", hubAppsecRuleItem.State.LocalPath, err) + log.Warnf("unable to parse file %s : %s", hubAppsecRuleItem.State.LocalPath, err) continue } diff --git a/pkg/appsec/query_utils.go b/pkg/appsec/query_utils.go new file mode 100644 index 00000000000..0c886e0ea51 --- /dev/null +++ b/pkg/appsec/query_utils.go @@ -0,0 +1,78 @@ +package appsec + +// This file is mostly stolen from net/url package, but with some modifications to allow less strict parsing of query strings + +import ( + "net/url" + "strings" +) + +// parseQuery and parseQuery are copied net/url package, but allow semicolon in values +func ParseQuery(query string) url.Values { + m := make(url.Values) + parseQuery(m, query) + return m +} + +func parseQuery(m url.Values, query string) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + //for now we'll just ignore the errors, but ideally we want to fire some "internal" rules when we see invalid query strings + key = unescape(key) + value = unescape(value) + m[key] = append(m[key], value) + } +} + +func hexDigitToByte(digit byte) (byte, bool) { + switch { + case digit >= '0' && digit <= '9': + return digit - '0', true + case digit >= 'a' && digit <= 'f': + return digit - 'a' + 10, true + case digit >= 'A' && digit <= 'F': + return digit - 'A' + 10, true + default: + return 0, false + } +} + +func unescape(input string) string { + ilen := len(input) + res := strings.Builder{} + res.Grow(ilen) + for i := 0; i < ilen; i++ { + ci := input[i] + if ci == '+' { + res.WriteByte(' ') + continue + } + if ci == '%' { + if i+2 >= ilen { + res.WriteByte(ci) + continue + } + hi, ok := hexDigitToByte(input[i+1]) + if !ok { + res.WriteByte(ci) + continue + } + lo, ok := hexDigitToByte(input[i+2]) + if !ok { + res.WriteByte(ci) + continue + } + res.WriteByte(hi<<4 | lo) + i += 2 + continue + } + res.WriteByte(ci) + } + return res.String() +} diff --git a/pkg/appsec/query_utils_test.go b/pkg/appsec/query_utils_test.go new file mode 100644 index 00000000000..2ad7927968d --- /dev/null +++ b/pkg/appsec/query_utils_test.go @@ -0,0 +1,207 @@ +package appsec + +import ( + "net/url" + "reflect" + "testing" +) + +func TestParseQuery(t *testing.T) { + tests := []struct { + name string + query string + expected url.Values + }{ + { + name: "Simple query", + query: "foo=bar", + expected: url.Values{ + "foo": []string{"bar"}, + }, + }, + { + name: "Multiple values", + query: "foo=bar&foo=baz", + expected: url.Values{ + "foo": []string{"bar", "baz"}, + }, + }, + { + name: "Empty value", + query: "foo=", + expected: url.Values{ + "foo": []string{""}, + }, + }, + { + name: "Empty key", + query: "=bar", + expected: url.Values{ + "": []string{"bar"}, + }, + }, + { + name: "Empty query", + query: "", + expected: url.Values{}, + }, + { + name: "Multiple keys", + query: "foo=bar&baz=qux", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + }, + }, + { + name: "Multiple keys with empty value", + query: "foo=bar&baz=qux&quux=", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + "quux": []string{""}, + }, + }, + { + name: "Multiple keys with empty value and empty key", + query: "foo=bar&baz=qux&quux=&=quuz", + expected: url.Values{ + "foo": []string{"bar"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz", + expected: url.Values{ + "foo": []string{"bar", "baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand and equals", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz&foo=bar%3Dbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz", "bar=baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "Multiple keys with empty value and empty key and multiple values and escaped characters and semicolon and ampersand and equals and question mark", + query: "foo=bar&baz=qux&quux=&=quuz&foo=baz&foo=bar%20baz&foo=bar%3Bbaz&foo=bar%26baz&foo=bar%3Dbaz&foo=bar%3Fbaz", + expected: url.Values{ + "foo": []string{"bar", "baz", "bar baz", "bar;baz", "bar&baz", "bar=baz", "bar?baz"}, + "baz": []string{"qux"}, + "quux": []string{""}, + "": []string{"quuz"}, + }, + }, + { + name: "keys with escaped characters", + query: "foo=ba;r&baz=qu;;x&quux=x\\&ww&xx=qu?uz&", + expected: url.Values{ + "foo": []string{"ba;r"}, + "baz": []string{"qu;;x"}, + "quux": []string{"x\\"}, + "ww": []string{""}, + "xx": []string{"qu?uz"}, + }, + }, + { + name: "hexadecimal characters", + query: "foo=bar%20baz", + expected: url.Values{ + "foo": []string{"bar baz"}, + }, + }, + { + name: "hexadecimal characters upper and lower case", + query: "foo=Ba%42%42&bar=w%2f%2F", + expected: url.Values{ + "foo": []string{"BaBB"}, + "bar": []string{"w//"}, + }, + }, + { + name: "hexadecimal characters with invalid characters", + query: "foo=bar%20baz%2", + expected: url.Values{ + "foo": []string{"bar baz%2"}, + }, + }, + { + name: "hexadecimal characters with invalid hex characters", + query: "foo=bar%xx", + expected: url.Values{ + "foo": []string{"bar%xx"}, + }, + }, + { + name: "hexadecimal characters with invalid 2nd hex character", + query: "foo=bar%2x", + expected: url.Values{ + "foo": []string{"bar%2x"}, + }, + }, + { + name: "url +", + query: "foo=bar+x", + expected: url.Values{ + "foo": []string{"bar x"}, + }, + }, + { + name: "url &&", + query: "foo=bar&&lol=bur", + expected: url.Values{ + "foo": []string{"bar"}, + "lol": []string{"bur"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res := ParseQuery(test.query) + if !reflect.DeepEqual(res, test.expected) { + t.Fatalf("unexpected result: %v", res) + } + }) + } +} diff --git a/pkg/appsec/request.go b/pkg/appsec/request.go index 6d472e8afae..ccd7a9f9cc8 100644 --- a/pkg/appsec/request.go +++ b/pkg/appsec/request.go @@ -12,16 +12,16 @@ import ( "regexp" "github.com/google/uuid" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) const ( - URIHeaderName = "X-Crowdsec-Appsec-Uri" - VerbHeaderName = "X-Crowdsec-Appsec-Verb" - HostHeaderName = "X-Crowdsec-Appsec-Host" - IPHeaderName = "X-Crowdsec-Appsec-Ip" - APIKeyHeaderName = "X-Crowdsec-Appsec-Api-Key" + URIHeaderName = "X-Crowdsec-Appsec-Uri" + VerbHeaderName = "X-Crowdsec-Appsec-Verb" + HostHeaderName = "X-Crowdsec-Appsec-Host" + IPHeaderName = "X-Crowdsec-Appsec-Ip" + APIKeyHeaderName = "X-Crowdsec-Appsec-Api-Key" + UserAgentHeaderName = "X-Crowdsec-Appsec-User-Agent" ) type ParsedRequest struct { @@ -275,7 +275,7 @@ func (r *ReqDumpFilter) ToJSON() error { } // Generate a ParsedRequest from a http.Request. ParsedRequest can be consumed by the App security Engine -func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedRequest, error) { +func NewParsedRequestFromRequest(r *http.Request, logger *log.Entry) (ParsedRequest, error) { var err error contentLength := r.ContentLength if contentLength < 0 { @@ -311,11 +311,15 @@ func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedR logger.Debugf("missing '%s' header", HostHeaderName) } + userAgent := r.Header.Get(UserAgentHeaderName) //This one is optional + // delete those headers before coraza process the request delete(r.Header, IPHeaderName) delete(r.Header, HostHeaderName) delete(r.Header, URIHeaderName) delete(r.Header, VerbHeaderName) + delete(r.Header, UserAgentHeaderName) + delete(r.Header, APIKeyHeaderName) originalHTTPRequest := r.Clone(r.Context()) originalHTTPRequest.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -323,6 +327,14 @@ func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedR originalHTTPRequest.RequestURI = clientURI originalHTTPRequest.Method = clientMethod originalHTTPRequest.Host = clientHost + if userAgent != "" { + originalHTTPRequest.Header.Set("User-Agent", userAgent) + r.Header.Set("User-Agent", userAgent) //Override the UA in the original request, as this is what will be used by the waf engine + } else { + //If we don't have a forwarded UA, delete the one that was set by the remediation in both original and incoming + originalHTTPRequest.Header.Del("User-Agent") + r.Header.Del("User-Agent") + } parsedURL, err := url.Parse(clientURI) if err != nil { @@ -330,6 +342,10 @@ func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedR } var remoteAddrNormalized string + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:65535" + } + // TODO we need to implement forwrded headers host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { log.Errorf("Invalid appsec remote IP source %v: %s", r.RemoteAddr, err.Error()) @@ -349,14 +365,14 @@ func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedR UUID: uuid.New().String(), ClientHost: clientHost, ClientIP: clientIP, - URI: parsedURL.Path, + URI: clientURI, Method: clientMethod, - Host: r.Host, + Host: clientHost, Headers: r.Header, - URL: r.URL, + URL: parsedURL, Proto: r.Proto, Body: body, - Args: parsedURL.Query(), //TODO: Check if there's not potential bypass as it excludes malformed args + Args: ParseQuery(parsedURL.RawQuery), TransferEncoding: r.TransferEncoding, ResponseChannel: make(chan AppsecTempResponse), RemoteAddrNormalized: remoteAddrNormalized, diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 1fd65dc38c3..8a696caf1f4 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -2,6 +2,7 @@ package cache import ( "errors" + "fmt" "time" "github.com/bluele/gcache" @@ -11,9 +12,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var Caches []gcache.Cache -var CacheNames []string -var CacheConfig []CacheCfg +var ( + Caches []gcache.Cache + CacheNames []string + CacheConfig []CacheCfg +) /*prometheus*/ var CacheMetrics = prometheus.NewGaugeVec( @@ -27,6 +30,7 @@ var CacheMetrics = prometheus.NewGaugeVec( // UpdateCacheMetrics is called directly by the prom handler func UpdateCacheMetrics() { CacheMetrics.Reset() + for i, name := range CacheNames { CacheMetrics.With(prometheus.Labels{"name": name, "type": CacheConfig[i].Strategy}).Set(float64(Caches[i].Len(false))) } @@ -42,27 +46,28 @@ type CacheCfg struct { } func CacheInit(cfg CacheCfg) error { - for _, name := range CacheNames { if name == cfg.Name { log.Infof("Cache %s already exists", cfg.Name) } } - //get a default logger + // get a default logger if cfg.LogLevel == nil { cfg.LogLevel = new(log.Level) *cfg.LogLevel = log.InfoLevel } - var clog = log.New() + + clog := log.New() + if err := types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating cache logger : %s", err) + return fmt.Errorf("while creating cache logger: %w", err) } + clog.SetLevel(*cfg.LogLevel) - cfg.Logger = clog.WithFields(log.Fields{ - "cache": cfg.Name, - }) + cfg.Logger = clog.WithField("cache", cfg.Name) tmpCache := gcache.New(cfg.Size) + switch cfg.Strategy { case "LRU": tmpCache = tmpCache.LRU() @@ -73,7 +78,6 @@ func CacheInit(cfg CacheCfg) error { default: cfg.Strategy = "LRU" tmpCache = tmpCache.LRU() - } CTICache := tmpCache.Build() @@ -85,36 +89,42 @@ func CacheInit(cfg CacheCfg) error { } func SetKey(cacheName string, key string, value string, expiration *time.Duration) error { - for i, name := range CacheNames { if name == cacheName { if expiration == nil { expiration = &CacheConfig[i].TTL } + CacheConfig[i].Logger.Debugf("Setting key %s to %s with expiration %v", key, value, *expiration) + if err := Caches[i].SetWithExpire(key, value, *expiration); err != nil { CacheConfig[i].Logger.Warningf("While setting key %s in cache %s: %s", key, cacheName, err) } } } + return nil } func GetKey(cacheName string, key string) (string, error) { for i, name := range CacheNames { if name == cacheName { - if value, err := Caches[i].Get(key); err != nil { - //do not warn or log if key not found + value, err := Caches[i].Get(key) + if err != nil { + // do not warn or log if key not found if errors.Is(err, gcache.KeyNotFoundError) { return "", nil } CacheConfig[i].Logger.Warningf("While getting key %s in cache %s: %s", key, cacheName, err) + return "", err - } else { - return value.(string), nil } + + return value.(string), nil } } + log.Warningf("Cache %s not found", cacheName) + return "", nil } diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index cdff39e700f..3014b729a9e 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -1,6 +1,7 @@ package csconfig import ( + "bytes" "crypto/tls" "crypto/x509" "errors" @@ -12,8 +13,9 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/go-cs-lib/csstring" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/yamlpatch" @@ -59,16 +61,15 @@ type CTICfg struct { func (a *CTICfg) Load() error { if a.Key == nil { - *a.Enabled = false + a.Enabled = ptr.Of(false) } if a.Key != nil && *a.Key == "" { - return fmt.Errorf("empty cti key") + return errors.New("empty cti key") } if a.Enabled == nil { - a.Enabled = new(bool) - *a.Enabled = true + a.Enabled = ptr.Of(true) } if a.CacheTimeout == nil { @@ -92,9 +93,14 @@ func (o *OnlineApiClientCfg) Load() error { return err } - err = yaml.UnmarshalStrict(fcontent, o.Credentials) + dec := yaml.NewDecoder(bytes.NewReader(fcontent)) + dec.KnownFields(true) + + err = dec.Decode(o.Credentials) if err != nil { - return fmt.Errorf("failed unmarshaling api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + if !errors.Is(err, io.EOF) { + return fmt.Errorf("failed to parse api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + } } switch { @@ -120,9 +126,16 @@ func (l *LocalApiClientCfg) Load() error { return err } - err = yaml.UnmarshalStrict(fcontent, &l.Credentials) + configData := csstring.StrictExpand(string(fcontent), os.LookupEnv) + + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&l.Credentials) if err != nil { - return fmt.Errorf("failed unmarshaling api client credential configuration file '%s': %w", l.CredentialsFilePath, err) + if !errors.Is(err, io.EOF) { + return fmt.Errorf("failed to parse api client credential configuration file '%s': %w", l.CredentialsFilePath, err) + } } if l.Credentials == nil || l.Credentials.URL == "" { @@ -130,13 +143,26 @@ func (l *LocalApiClientCfg) Load() error { } if l.Credentials != nil && l.Credentials.URL != "" { - if !strings.HasSuffix(l.Credentials.URL, "/") { + // don't append a trailing slash if the URL is a unix socket + if strings.HasPrefix(l.Credentials.URL, "http") && !strings.HasSuffix(l.Credentials.URL, "/") { l.Credentials.URL += "/" } } - if l.Credentials.Login != "" && (l.Credentials.CertPath != "" || l.Credentials.KeyPath != "") { - return fmt.Errorf("user/password authentication and TLS authentication are mutually exclusive") + // is the configuration asking for client authentication via TLS? + credTLSClientAuth := l.Credentials.CertPath != "" || l.Credentials.KeyPath != "" + + // is the configuration asking for TLS encryption and server authentication? + credTLS := credTLSClientAuth || l.Credentials.CACertPath != "" + + credSocket := strings.HasPrefix(l.Credentials.URL, "/") + + if credTLS && credSocket { + return errors.New("cannot use TLS with a unix socket") + } + + if credTLSClientAuth && l.Credentials.Login != "" { + return errors.New("user/password authentication and TLS authentication are mutually exclusive") } if l.InsecureSkipVerify == nil { @@ -176,9 +202,10 @@ func (l *LocalApiClientCfg) Load() error { return nil } -func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { +func (c *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { trustedIPs := make([]net.IPNet, 0) - for _, ip := range lapiCfg.TrustedIPs { + + for _, ip := range c.TrustedIPs { cidr := toValidCIDR(ip) _, ipNet, err := net.ParseCIDR(cidr) @@ -209,31 +236,56 @@ type CapiWhitelist struct { Cidrs []*net.IPNet `yaml:"cidrs,omitempty"` } +type LocalAPIAutoRegisterCfg struct { + Enable *bool `yaml:"enabled"` + Token string `yaml:"token"` + AllowedRanges []string `yaml:"allowed_ranges,omitempty"` + AllowedRangesParsed []*net.IPNet `yaml:"-"` +} + /*local api service configuration*/ type LocalApiServerCfg struct { - Enable *bool `yaml:"enable"` - ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080 - TLS *TLSCfg `yaml:"tls"` - DbConfig *DatabaseCfg `yaml:"-"` - LogDir string `yaml:"-"` - LogMedia string `yaml:"-"` - OnlineClient *OnlineApiClientCfg `yaml:"online_client"` - ProfilesPath string `yaml:"profiles_path,omitempty"` - ConsoleConfigPath string `yaml:"console_path,omitempty"` - ConsoleConfig *ConsoleConfig `yaml:"-"` - Profiles []*ProfileCfg `yaml:"-"` - LogLevel *log.Level `yaml:"log_level"` - UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"` - TrustedProxies *[]string `yaml:"trusted_proxies,omitempty"` - CompressLogs *bool `yaml:"-"` - LogMaxSize int `yaml:"-"` - LogMaxAge int `yaml:"-"` - LogMaxFiles int `yaml:"-"` - TrustedIPs []string `yaml:"trusted_ips,omitempty"` - PapiLogLevel *log.Level `yaml:"papi_log_level"` - DisableRemoteLapiRegistration bool `yaml:"disable_remote_lapi_registration,omitempty"` - CapiWhitelistsPath string `yaml:"capi_whitelists_path,omitempty"` - CapiWhitelists *CapiWhitelist `yaml:"-"` + Enable *bool `yaml:"enable"` + ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080 + ListenSocket string `yaml:"listen_socket,omitempty"` + TLS *TLSCfg `yaml:"tls"` + DbConfig *DatabaseCfg `yaml:"-"` + LogDir string `yaml:"-"` + LogMedia string `yaml:"-"` + OnlineClient *OnlineApiClientCfg `yaml:"online_client"` + ProfilesPath string `yaml:"profiles_path,omitempty"` + ConsoleConfigPath string `yaml:"console_path,omitempty"` + ConsoleConfig *ConsoleConfig `yaml:"-"` + Profiles []*ProfileCfg `yaml:"-"` + LogLevel *log.Level `yaml:"log_level"` + UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"` + TrustedProxies *[]string `yaml:"trusted_proxies,omitempty"` + CompressLogs *bool `yaml:"-"` + LogMaxSize int `yaml:"-"` + LogMaxAge int `yaml:"-"` + LogMaxFiles int `yaml:"-"` + TrustedIPs []string `yaml:"trusted_ips,omitempty"` + PapiLogLevel *log.Level `yaml:"papi_log_level"` + DisableRemoteLapiRegistration bool `yaml:"disable_remote_lapi_registration,omitempty"` + CapiWhitelistsPath string `yaml:"capi_whitelists_path,omitempty"` + CapiWhitelists *CapiWhitelist `yaml:"-"` + AutoRegister *LocalAPIAutoRegisterCfg `yaml:"auto_registration,omitempty"` +} + +func (c *LocalApiServerCfg) ClientURL() string { + if c == nil { + return "" + } + + if c.ListenSocket != "" { + return c.ListenSocket + } + + if c.ListenURI != "" { + return "http://" + c.ListenURI + } + + return "" } func (c *Config) LoadAPIServer(inCli bool) error { @@ -243,7 +295,9 @@ func (c *Config) LoadAPIServer(inCli bool) error { if c.API.Server == nil { log.Warning("crowdsec local API is disabled") + c.DisableAPI = true + return nil } @@ -254,6 +308,7 @@ func (c *Config) LoadAPIServer(inCli bool) error { if !*c.API.Server.Enable { log.Warning("crowdsec local API is disabled because 'enable' is set to false") + c.DisableAPI = true } @@ -261,11 +316,11 @@ func (c *Config) LoadAPIServer(inCli bool) error { return nil } - if c.API.Server.ListenURI == "" { - return fmt.Errorf("no listen_uri specified") + if c.API.Server.ListenURI == "" && c.API.Server.ListenSocket == "" { + return errors.New("no listen_uri or listen_socket specified") } - //inherit log level from common, then api->server + // inherit log level from common, then api->server var logLevel log.Level if c.API.Server.LogLevel != nil { logLevel = *c.API.Server.LogLevel @@ -301,6 +356,14 @@ func (c *Config) LoadAPIServer(inCli bool) error { log.Infof("loaded capi whitelist from %s: %d IPs, %d CIDRs", c.API.Server.CapiWhitelistsPath, len(c.API.Server.CapiWhitelists.Ips), len(c.API.Server.CapiWhitelists.Cidrs)) } + if err := c.API.Server.LoadAutoRegister(); err != nil { + return err + } + + if c.API.Server.AutoRegister != nil && c.API.Server.AutoRegister.Enable != nil && *c.API.Server.AutoRegister.Enable && !inCli { + log.Infof("auto LAPI registration enabled for ranges %+v", c.API.Server.AutoRegister.AllowedRanges) + } + c.API.Server.LogDir = c.Common.LogDir c.API.Server.LogMedia = c.Common.LogMedia c.API.Server.CompressLogs = c.Common.CompressLogs @@ -349,7 +412,7 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) { decoder := yaml.NewDecoder(fd) if err := decoder.Decode(&fromCfg); err != nil { if errors.Is(err, io.EOF) { - return nil, fmt.Errorf("empty file") + return nil, errors.New("empty file") } return nil, err @@ -381,21 +444,21 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) { return ret, nil } -func (s *LocalApiServerCfg) LoadCapiWhitelists() error { - if s.CapiWhitelistsPath == "" { +func (c *LocalApiServerCfg) LoadCapiWhitelists() error { + if c.CapiWhitelistsPath == "" { return nil } - fd, err := os.Open(s.CapiWhitelistsPath) + fd, err := os.Open(c.CapiWhitelistsPath) if err != nil { - return fmt.Errorf("while opening capi whitelist file: %s", err) + return fmt.Errorf("while opening capi whitelist file: %w", err) } defer fd.Close() - s.CapiWhitelists, err = parseCapiWhitelists(fd) + c.CapiWhitelists, err = parseCapiWhitelists(fd) if err != nil { - return fmt.Errorf("while parsing capi whitelist file '%s': %w", s.CapiWhitelistsPath, err) + return fmt.Errorf("while parsing capi whitelist file '%s': %w", c.CapiWhitelistsPath, err) } return nil @@ -403,11 +466,51 @@ func (s *LocalApiServerCfg) LoadCapiWhitelists() error { func (c *Config) LoadAPIClient() error { if c.API == nil || c.API.Client == nil || c.API.Client.CredentialsFilePath == "" || c.DisableAgent { - return fmt.Errorf("no API client section in configuration") + return errors.New("no API client section in configuration") } - if err := c.API.Client.Load(); err != nil { - return err + return c.API.Client.Load() +} + +func (c *LocalApiServerCfg) LoadAutoRegister() error { + if c.AutoRegister == nil { + c.AutoRegister = &LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(false), + } + + return nil + } + + // Disable by default + if c.AutoRegister.Enable == nil { + c.AutoRegister.Enable = ptr.Of(false) + } + + if !*c.AutoRegister.Enable { + return nil + } + + if c.AutoRegister.Token == "" { + return errors.New("missing token value for api.server.auto_register") + } + + if len(c.AutoRegister.Token) < 32 { + return errors.New("token value for api.server.auto_register is too short (min 32 characters)") + } + + if c.AutoRegister.AllowedRanges == nil { + return errors.New("missing allowed_ranges value for api.server.auto_register") + } + + c.AutoRegister.AllowedRangesParsed = make([]*net.IPNet, 0, len(c.AutoRegister.AllowedRanges)) + + for _, ipRange := range c.AutoRegister.AllowedRanges { + _, ipNet, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("auto_register: failed to parse allowed range '%s': %w", ipRange, err) + } + + c.AutoRegister.AllowedRangesParsed = append(c.AutoRegister.AllowedRangesParsed, ipNet) } return nil diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index e22c78204e7..dff3c3afc8c 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -9,7 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -64,10 +64,10 @@ func TestLoadLocalApiClientCfg(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.Load() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } @@ -101,7 +101,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { CredentialsFilePath: "./testdata/bad_lapi-secrets.yaml", }, expected: &ApiCredentialsCfg{}, - expectedErr: "failed unmarshaling api server credentials", + expectedErr: "failed to parse api server credentials", }, { name: "missing field configuration", @@ -121,10 +121,10 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.Load() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } @@ -147,7 +147,11 @@ func TestLoadAPIServer(t *testing.T) { require.NoError(t, err) configData := os.ExpandEnv(string(fcontent)) - err = yaml.UnmarshalStrict([]byte(configData), &config) + + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&config) require.NoError(t, err) tests := []struct { @@ -187,7 +191,8 @@ func TestLoadAPIServer(t *testing.T) { DbConfig: &DatabaseCfg{ DbPath: "./testdata/test.db", Type: "sqlite", - MaxOpenConns: ptr.Of(DEFAULT_MAX_OPEN_CONNS), + MaxOpenConns: DEFAULT_MAX_OPEN_CONNS, + UseWal: ptr.Of(true), // autodetected DecisionBulkSize: defaultDecisionBulkSize, }, ConsoleConfigPath: DefaultConfigPath("console.yaml"), @@ -212,6 +217,12 @@ func TestLoadAPIServer(t *testing.T) { ProfilesPath: "./testdata/profiles.yaml", UseForwardedForHeaders: false, PapiLogLevel: &logLevel, + AutoRegister: &LocalAPIAutoRegisterCfg{ + Enable: ptr.Of(false), + Token: "", + AllowedRanges: nil, + AllowedRangesParsed: nil, + }, }, }, { @@ -238,10 +249,10 @@ func TestLoadAPIServer(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.LoadAPIServer(false) cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } @@ -301,10 +312,10 @@ func TestParseCapiWhitelists(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { wl, err := parseCapiWhitelists(strings.NewReader(tc.input)) cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } diff --git a/pkg/csconfig/config.go b/pkg/csconfig/config.go index a704414952e..3bbdf607187 100644 --- a/pkg/csconfig/config.go +++ b/pkg/csconfig/config.go @@ -1,18 +1,22 @@ // Package csconfig contains the configuration structures for crowdsec and cscli. - package csconfig import ( + "errors" "fmt" + "io" "os" "path/filepath" + "strings" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/csstring" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/yamlpatch" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" ) // defaultConfigDir is the base path to all configuration files, to be overridden in the Makefile */ @@ -25,7 +29,7 @@ var globalConfig = Config{} // Config contains top-level defaults -> overridden by configuration file -> overridden by CLI flags type Config struct { - //just a path to ourselves :p + // just a path to ourselves :p FilePath *string `yaml:"-"` Self []byte `yaml:"-"` Common *CommonCfg `yaml:"common,omitempty"` @@ -44,10 +48,12 @@ type Config struct { func NewConfig(configFile string, disableAgent bool, disableAPI bool, inCli bool) (*Config, string, error) { patcher := yamlpatch.NewPatcher(configFile, ".local") patcher.SetQuiet(inCli) + fcontent, err := patcher.MergedPatchContent() if err != nil { return nil, "", err } + configData := csstring.StrictExpand(string(fcontent), os.LookupEnv) cfg := Config{ FilePath: &configFile, @@ -55,10 +61,15 @@ func NewConfig(configFile string, disableAgent bool, disableAPI bool, inCli bool DisableAPI: disableAPI, } - err = yaml.UnmarshalStrict([]byte(configData), &cfg) + dec := yaml.NewDecoder(strings.NewReader(configData)) + dec.KnownFields(true) + + err = dec.Decode(&cfg) if err != nil { - // this is actually the "merged" yaml - return nil, "", fmt.Errorf("%s: %w", configFile, err) + if !errors.Is(err, io.EOF) { + // this is actually the "merged" yaml + return nil, "", fmt.Errorf("%s: %w", configFile, err) + } } if cfg.Prometheus == nil { @@ -109,7 +120,7 @@ func NewDefaultConfig() *Config { } prometheus := PrometheusCfg{ Enabled: true, - Level: "full", + Level: configuration.CFG_METRICS_FULL, } configPaths := ConfigurationPaths{ ConfigDir: DefaultConfigPath("."), @@ -147,7 +158,7 @@ func NewDefaultConfig() *Config { dbConfig := DatabaseCfg{ Type: "sqlite", DbPath: DefaultDataPath("crowdsec.db"), - MaxOpenConns: ptr.Of(DEFAULT_MAX_OPEN_CONNS), + MaxOpenConns: DEFAULT_MAX_OPEN_CONNS, } globalCfg := Config{ diff --git a/pkg/csconfig/config_paths.go b/pkg/csconfig/config_paths.go index 71e3bacdaac..a8d39a664f3 100644 --- a/pkg/csconfig/config_paths.go +++ b/pkg/csconfig/config_paths.go @@ -1,6 +1,7 @@ package csconfig import ( + "errors" "fmt" "path/filepath" ) @@ -9,31 +10,36 @@ type ConfigurationPaths struct { ConfigDir string `yaml:"config_dir"` DataDir string `yaml:"data_dir,omitempty"` SimulationFilePath string `yaml:"simulation_path,omitempty"` - HubIndexFile string `yaml:"index_path,omitempty"` //path of the .index.json + HubIndexFile string `yaml:"index_path,omitempty"` // path of the .index.json HubDir string `yaml:"hub_dir,omitempty"` PluginDir string `yaml:"plugin_dir,omitempty"` NotificationDir string `yaml:"notification_dir,omitempty"` + PatternDir string `yaml:"pattern_dir,omitempty"` } func (c *Config) loadConfigurationPaths() error { var err error if c.ConfigPaths == nil { - return fmt.Errorf("no configuration paths provided") + return errors.New("no configuration paths provided") } if c.ConfigPaths.DataDir == "" { - return fmt.Errorf("please provide a data directory with the 'data_dir' directive in the 'config_paths' section") + return errors.New("please provide a data directory with the 'data_dir' directive in the 'config_paths' section") } if c.ConfigPaths.HubDir == "" { - c.ConfigPaths.HubDir = filepath.Clean(c.ConfigPaths.ConfigDir + "/hub") + c.ConfigPaths.HubDir = filepath.Join(c.ConfigPaths.ConfigDir, "hub") } if c.ConfigPaths.HubIndexFile == "" { - c.ConfigPaths.HubIndexFile = filepath.Clean(c.ConfigPaths.HubDir + "/.index.json") + c.ConfigPaths.HubIndexFile = filepath.Join(c.ConfigPaths.HubDir, ".index.json") } - var configPathsCleanup = []*string{ + if c.ConfigPaths.PatternDir == "" { + c.ConfigPaths.PatternDir = filepath.Join(c.ConfigPaths.ConfigDir, "patterns") + } + + configPathsCleanup := []*string{ &c.ConfigPaths.HubDir, &c.ConfigPaths.HubIndexFile, &c.ConfigPaths.ConfigDir, @@ -41,6 +47,7 @@ func (c *Config) loadConfigurationPaths() error { &c.ConfigPaths.SimulationFilePath, &c.ConfigPaths.PluginDir, &c.ConfigPaths.NotificationDir, + &c.ConfigPaths.PatternDir, } for _, k := range configPathsCleanup { if *k == "" { diff --git a/pkg/csconfig/config_test.go b/pkg/csconfig/config_test.go index 4843c2f70f9..b69954de178 100644 --- a/pkg/csconfig/config_test.go +++ b/pkg/csconfig/config_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/cstest" ) @@ -32,7 +32,6 @@ func TestNewCrowdSecConfig(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { result := &Config{} assert.Equal(t, tc.expected, result) @@ -43,5 +42,5 @@ func TestNewCrowdSecConfig(t *testing.T) { func TestDefaultConfig(t *testing.T) { x := NewDefaultConfig() _, err := yaml.Marshal(x) - require.NoError(t, err, "failed marshaling config: %s", err) + require.NoError(t, err, "failed to serialize config: %s", err) } diff --git a/pkg/csconfig/console.go b/pkg/csconfig/console.go index 1e8974154ec..21ecbf3d736 100644 --- a/pkg/csconfig/console.go +++ b/pkg/csconfig/console.go @@ -5,7 +5,7 @@ import ( "os" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/ptr" ) @@ -37,10 +37,40 @@ type ConsoleConfig struct { ShareContext *bool `yaml:"share_context"` } +func (c *ConsoleConfig) EnabledOptions() []string { + ret := []string{} + if c == nil { + return ret + } + + if c.ShareCustomScenarios != nil && *c.ShareCustomScenarios { + ret = append(ret, SEND_CUSTOM_SCENARIOS) + } + + if c.ShareTaintedScenarios != nil && *c.ShareTaintedScenarios { + ret = append(ret, SEND_TAINTED_SCENARIOS) + } + + if c.ShareManualDecisions != nil && *c.ShareManualDecisions { + ret = append(ret, SEND_MANUAL_SCENARIOS) + } + + if c.ConsoleManagement != nil && *c.ConsoleManagement { + ret = append(ret, CONSOLE_MANAGEMENT) + } + + if c.ShareContext != nil && *c.ShareContext { + ret = append(ret, SEND_CONTEXT) + } + + return ret +} + func (c *ConsoleConfig) IsPAPIEnabled() bool { if c == nil || c.ConsoleManagement == nil { return false } + return *c.ConsoleManagement } @@ -48,31 +78,36 @@ func (c *LocalApiServerCfg) LoadConsoleConfig() error { c.ConsoleConfig = &ConsoleConfig{} if _, err := os.Stat(c.ConsoleConfigPath); err != nil && os.IsNotExist(err) { log.Debugf("no console configuration to load") + c.ConsoleConfig.ShareCustomScenarios = ptr.Of(true) c.ConsoleConfig.ShareTaintedScenarios = ptr.Of(true) c.ConsoleConfig.ShareManualDecisions = ptr.Of(false) c.ConsoleConfig.ConsoleManagement = ptr.Of(false) c.ConsoleConfig.ShareContext = ptr.Of(false) + return nil } yamlFile, err := os.ReadFile(c.ConsoleConfigPath) if err != nil { - return fmt.Errorf("reading console config file '%s': %s", c.ConsoleConfigPath, err) + return fmt.Errorf("reading console config file '%s': %w", c.ConsoleConfigPath, err) } + err = yaml.Unmarshal(yamlFile, c.ConsoleConfig) if err != nil { - return fmt.Errorf("unmarshaling console config file '%s': %s", c.ConsoleConfigPath, err) + return fmt.Errorf("parsing console config file '%s': %w", c.ConsoleConfigPath, err) } if c.ConsoleConfig.ShareCustomScenarios == nil { log.Debugf("no share_custom scenarios found, setting to true") c.ConsoleConfig.ShareCustomScenarios = ptr.Of(true) } + if c.ConsoleConfig.ShareTaintedScenarios == nil { log.Debugf("no share_tainted scenarios found, setting to true") c.ConsoleConfig.ShareTaintedScenarios = ptr.Of(true) } + if c.ConsoleConfig.ShareManualDecisions == nil { log.Debugf("no share_manual scenarios found, setting to false") c.ConsoleConfig.ShareManualDecisions = ptr.Of(false) diff --git a/pkg/csconfig/crowdsec_service.go b/pkg/csconfig/crowdsec_service.go index 36d38cf7481..cf796805dee 100644 --- a/pkg/csconfig/crowdsec_service.go +++ b/pkg/csconfig/crowdsec_service.go @@ -6,7 +6,7 @@ import ( "path/filepath" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/ptr" ) @@ -133,27 +133,24 @@ func (c *Config) LoadCrowdsec() error { } if err = c.LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %s", err) + return fmt.Errorf("loading api client: %w", err) } return nil } func (c *CrowdsecServiceCfg) DumpContextConfigFile() error { - var out []byte - var err error - // XXX: MakeDirs - - if out, err = yaml.Marshal(c.ContextToSend); err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) + out, err := yaml.Marshal(c.ContextToSend) + if err != nil { + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) } - if err = os.MkdirAll(filepath.Dir(c.ConsoleContextPath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(c.ConsoleContextPath), 0o700); err != nil { return fmt.Errorf("while creating directories for %s: %w", c.ConsoleContextPath, err) } - if err := os.WriteFile(c.ConsoleContextPath, out, 0600); err != nil { + if err := os.WriteFile(c.ConsoleContextPath, out, 0o600); err != nil { return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleContextPath, err) } diff --git a/pkg/csconfig/crowdsec_service_test.go b/pkg/csconfig/crowdsec_service_test.go index 8d332271b03..7570b63011e 100644 --- a/pkg/csconfig/crowdsec_service_test.go +++ b/pkg/csconfig/crowdsec_service_test.go @@ -61,9 +61,9 @@ func TestLoadCrowdsec(t *testing.T) { AcquisitionFiles: []string{acquisFullPath}, SimulationFilePath: "./testdata/simulation.yaml", // context is loaded in pkg/alertcontext -// ContextToSend: map[string][]string{ -// "source_ip": {"evt.Parsed.source_ip"}, -// }, + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), }, @@ -100,9 +100,9 @@ func TestLoadCrowdsec(t *testing.T) { ConsoleContextValueLength: 0, AcquisitionFiles: []string{acquisFullPath, acquisInDirFullPath}, // context is loaded in pkg/alertcontext -// ContextToSend: map[string][]string{ -// "source_ip": {"evt.Parsed.source_ip"}, -// }, + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, SimulationFilePath: "./testdata/simulation.yaml", SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), @@ -139,9 +139,9 @@ func TestLoadCrowdsec(t *testing.T) { AcquisitionFiles: []string{}, SimulationFilePath: "", // context is loaded in pkg/alertcontext -// ContextToSend: map[string][]string{ -// "source_ip": {"evt.Parsed.source_ip"}, -// }, + // ContextToSend: map[string][]string{ + // "source_ip": {"evt.Parsed.source_ip"}, + // }, SimulationConfig: &SimulationConfig{ Simulation: ptr.Of(false), }, @@ -181,10 +181,10 @@ func TestLoadCrowdsec(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.LoadCrowdsec() cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } diff --git a/pkg/csconfig/cscli.go b/pkg/csconfig/cscli.go index 7fff03864ef..9393156c0ed 100644 --- a/pkg/csconfig/cscli.go +++ b/pkg/csconfig/cscli.go @@ -6,18 +6,18 @@ import ( /*cscli specific config, such as hub directory*/ type CscliCfg struct { - Output string `yaml:"output,omitempty"` - Color string `yaml:"color,omitempty"` - HubBranch string `yaml:"hub_branch"` - HubURLTemplate string `yaml:"__hub_url_template__,omitempty"` - SimulationConfig *SimulationConfig `yaml:"-"` - DbConfig *DatabaseCfg `yaml:"-"` - - SimulationFilePath string `yaml:"-"` - PrometheusUrl string `yaml:"prometheus_uri"` + Output string `yaml:"output,omitempty"` + Color string `yaml:"color,omitempty"` + HubBranch string `yaml:"hub_branch"` + HubURLTemplate string `yaml:"__hub_url_template__,omitempty"` + SimulationConfig *SimulationConfig `yaml:"-"` + DbConfig *DatabaseCfg `yaml:"-"` + + SimulationFilePath string `yaml:"-"` + PrometheusUrl string `yaml:"prometheus_uri"` } -const defaultHubURLTemplate = "https://hub-cdn.crowdsec.net/%s/%s" +const defaultHubURLTemplate = "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s" func (c *Config) loadCSCLI() error { if c.Cscli == nil { diff --git a/pkg/csconfig/cscli_test.go b/pkg/csconfig/cscli_test.go index 807f02d216c..a58fdd6f857 100644 --- a/pkg/csconfig/cscli_test.go +++ b/pkg/csconfig/cscli_test.go @@ -39,7 +39,6 @@ func TestLoadCSCLI(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.loadCSCLI() cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 5149b4ae39e..4ca582cf576 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -1,13 +1,17 @@ package csconfig import ( + "errors" "fmt" + "path/filepath" "time" "entgo.io/ent/dialect" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) const ( @@ -29,7 +33,7 @@ type DatabaseCfg struct { Type string `yaml:"type"` Flush *FlushDBCfg `yaml:"flush"` LogLevel *log.Level `yaml:"log_level"` - MaxOpenConns *int `yaml:"max_open_conns,omitempty"` + MaxOpenConns int `yaml:"max_open_conns,omitempty"` UseWal *bool `yaml:"use_wal,omitempty"` DecisionBulkSize int `yaml:"decision_bulk_size,omitempty"` } @@ -44,15 +48,17 @@ type AuthGCCfg struct { } type FlushDBCfg struct { - MaxItems *int `yaml:"max_items,omitempty"` - MaxAge *string `yaml:"max_age,omitempty"` - BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"` - AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"` + MaxItems *int `yaml:"max_items,omitempty"` + // We could unmarshal as time.Duration, but alert filters right now are a map of strings + MaxAge *string `yaml:"max_age,omitempty"` + BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"` + AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"` + MetricsMaxAge *time.Duration `yaml:"metrics_max_age,omitempty"` } func (c *Config) LoadDBConfig(inCli bool) error { if c.DbConfig == nil { - return fmt.Errorf("no database configuration provided") + return errors.New("no database configuration provided") } if c.Cscli != nil { @@ -63,8 +69,38 @@ func (c *Config) LoadDBConfig(inCli bool) error { c.API.Server.DbConfig = c.DbConfig } - if c.DbConfig.MaxOpenConns == nil { - c.DbConfig.MaxOpenConns = ptr.Of(DEFAULT_MAX_OPEN_CONNS) + if c.DbConfig.MaxOpenConns == 0 { + c.DbConfig.MaxOpenConns = DEFAULT_MAX_OPEN_CONNS + } + + if !inCli && c.DbConfig.Type == "sqlite" { + if c.DbConfig.UseWal == nil { + dbDir := filepath.Dir(c.DbConfig.DbPath) + isNetwork, fsType, err := types.IsNetworkFS(dbDir) + switch { + case err != nil: + log.Warnf("unable to determine if database is on network filesystem: %s", err) + log.Warning( + "You are using sqlite without WAL, this can have a performance impact. " + + "If you do not store the database in a network share, set db_config.use_wal to true. " + + "Set explicitly to false to disable this warning.") + case isNetwork: + log.Debugf("database is on network filesystem (%s), setting useWal to false", fsType) + c.DbConfig.UseWal = ptr.Of(false) + default: + log.Debugf("database is on local filesystem (%s), setting useWal to true", fsType) + c.DbConfig.UseWal = ptr.Of(true) + } + } else if *c.DbConfig.UseWal { + dbDir := filepath.Dir(c.DbConfig.DbPath) + isNetwork, fsType, err := types.IsNetworkFS(dbDir) + switch { + case err != nil: + log.Warnf("unable to determine if database is on network filesystem: %s", err) + case isNetwork: + log.Warnf("database seems to be stored on a network share (%s), but useWal is set to true. Proceed at your own risk.", fsType) + } + } } if c.DbConfig.DecisionBulkSize == 0 { @@ -77,15 +113,12 @@ func (c *Config) LoadDBConfig(inCli bool) error { c.DbConfig.DecisionBulkSize = maxDecisionBulkSize } - if !inCli && c.DbConfig.Type == "sqlite" && c.DbConfig.UseWal == nil { - log.Warning("You are using sqlite without WAL, this can have a performance impact. If you do not store the database in a network share, set db_config.use_wal to true. Set explicitly to false to disable this warning.") - } - return nil } func (d *DatabaseCfg) ConnectionString() string { connString := "" + switch d.Type { case "sqlite": var sqliteConnectionStringParameters string @@ -94,6 +127,7 @@ func (d *DatabaseCfg) ConnectionString() string { } else { sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1" } + connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters) case "mysql": if d.isSocketConfig() { @@ -101,6 +135,10 @@ func (d *DatabaseCfg) ConnectionString() string { } else { connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName) } + + if d.Sslmode != "" { + connString = fmt.Sprintf("%s&tls=%s", connString, d.Sslmode) + } case "postgres", "postgresql", "pgx": if d.isSocketConfig() { connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password) @@ -108,6 +146,7 @@ func (d *DatabaseCfg) ConnectionString() string { connString = fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", d.Host, d.Port, d.User, d.DbName, d.Password, d.Sslmode) } } + return connString } @@ -121,8 +160,10 @@ func (d *DatabaseCfg) ConnectionDialect() (string, string, error) { if d.Type != "pgx" { log.Debugf("database type '%s' is deprecated, switching to 'pgx' instead", d.Type) } + return "pgx", dialect.Postgres, nil } + return "", "", fmt.Errorf("unknown database type '%s'", d.Type) } diff --git a/pkg/csconfig/database_test.go b/pkg/csconfig/database_test.go index a946025799d..4a1ef807f97 100644 --- a/pkg/csconfig/database_test.go +++ b/pkg/csconfig/database_test.go @@ -22,7 +22,7 @@ func TestLoadDBConfig(t *testing.T) { DbConfig: &DatabaseCfg{ Type: "sqlite", DbPath: "./testdata/test.db", - MaxOpenConns: ptr.Of(10), + MaxOpenConns: 10, }, Cscli: &CscliCfg{}, API: &APICfg{ @@ -32,7 +32,8 @@ func TestLoadDBConfig(t *testing.T) { expected: &DatabaseCfg{ Type: "sqlite", DbPath: "./testdata/test.db", - MaxOpenConns: ptr.Of(10), + MaxOpenConns: 10, + UseWal: ptr.Of(true), DecisionBulkSize: defaultDecisionBulkSize, }, }, @@ -45,10 +46,10 @@ func TestLoadDBConfig(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.LoadDBConfig(false) cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } diff --git a/pkg/csconfig/fflag.go b/pkg/csconfig/fflag.go index 7311f9e751a..c86686889eb 100644 --- a/pkg/csconfig/fflag.go +++ b/pkg/csconfig/fflag.go @@ -12,10 +12,7 @@ import ( // LoadFeatureFlagsEnv parses the environment variables to enable feature flags. func LoadFeatureFlagsEnv(logger *log.Logger) error { - if err := fflag.Crowdsec.SetFromEnv(logger); err != nil { - return err - } - return nil + return fflag.Crowdsec.SetFromEnv(logger) } // FeatureFlagsFileLocation returns the path to the feature.yaml file. diff --git a/pkg/csconfig/hub_test.go b/pkg/csconfig/hub_test.go index 2f9528c6043..49d010a04f4 100644 --- a/pkg/csconfig/hub_test.go +++ b/pkg/csconfig/hub_test.go @@ -35,7 +35,6 @@ func TestLoadHub(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.loadHub() cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csconfig/profiles.go b/pkg/csconfig/profiles.go index ad3779ed12f..6fbb8ed8b21 100644 --- a/pkg/csconfig/profiles.go +++ b/pkg/csconfig/profiles.go @@ -6,7 +6,7 @@ import ( "fmt" "io" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/yamlpatch" @@ -23,43 +23,50 @@ import ( type ProfileCfg struct { Name string `yaml:"name,omitempty"` Debug *bool `yaml:"debug,omitempty"` - Filters []string `yaml:"filters,omitempty"` //A list of OR'ed expressions. the models.Alert object + Filters []string `yaml:"filters,omitempty"` // A list of OR'ed expressions. the models.Alert object Decisions []models.Decision `yaml:"decisions,omitempty"` DurationExpr string `yaml:"duration_expr,omitempty"` - OnSuccess string `yaml:"on_success,omitempty"` //continue or break - OnFailure string `yaml:"on_failure,omitempty"` //continue or break - OnError string `yaml:"on_error,omitempty"` //continue, break, error, report, apply, ignore + OnSuccess string `yaml:"on_success,omitempty"` // continue or break + OnFailure string `yaml:"on_failure,omitempty"` // continue or break + OnError string `yaml:"on_error,omitempty"` // continue, break, error, report, apply, ignore Notifications []string `yaml:"notifications,omitempty"` } func (c *LocalApiServerCfg) LoadProfiles() error { if c.ProfilesPath == "" { - return fmt.Errorf("empty profiles path") + return errors.New("empty profiles path") } patcher := yamlpatch.NewPatcher(c.ProfilesPath, ".local") + fcontent, err := patcher.PrependedPatchContent() if err != nil { return err } + reader := bytes.NewReader(fcontent) dec := yaml.NewDecoder(reader) - dec.SetStrict(true) + dec.KnownFields(true) + for { t := ProfileCfg{} + err = dec.Decode(&t) if err != nil { if errors.Is(err, io.EOF) { break } + return fmt.Errorf("while decoding %s: %w", c.ProfilesPath, err) } + c.Profiles = append(c.Profiles, &t) } if len(c.Profiles) == 0 { - return fmt.Errorf("zero profiles loaded for LAPI") + return errors.New("zero profiles loaded for LAPI") } + return nil } diff --git a/pkg/csconfig/simulation.go b/pkg/csconfig/simulation.go index 0d09aa478ff..c9041df464a 100644 --- a/pkg/csconfig/simulation.go +++ b/pkg/csconfig/simulation.go @@ -1,10 +1,13 @@ package csconfig import ( + "bytes" + "errors" "fmt" + "io" "path/filepath" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/yamlpatch" ) @@ -20,37 +23,50 @@ func (s *SimulationConfig) IsSimulated(scenario string) bool { if s.Simulation != nil && *s.Simulation { simulated = true } + for _, excluded := range s.Exclusions { if excluded == scenario { - simulated = !simulated - break + return !simulated } } + return simulated } func (c *Config) LoadSimulation() error { simCfg := SimulationConfig{} + if c.ConfigPaths.SimulationFilePath == "" { - c.ConfigPaths.SimulationFilePath = filepath.Clean(c.ConfigPaths.ConfigDir + "/simulation.yaml") + c.ConfigPaths.SimulationFilePath = filepath.Join(c.ConfigPaths.ConfigDir, "simulation.yaml") } patcher := yamlpatch.NewPatcher(c.ConfigPaths.SimulationFilePath, ".local") + rcfg, err := patcher.MergedPatchContent() if err != nil { return err } - if err := yaml.UnmarshalStrict(rcfg, &simCfg); err != nil { - return fmt.Errorf("while unmarshaling simulation file '%s' : %s", c.ConfigPaths.SimulationFilePath, err) + + dec := yaml.NewDecoder(bytes.NewReader(rcfg)) + dec.KnownFields(true) + + if err := dec.Decode(&simCfg); err != nil { + if !errors.Is(err, io.EOF) { + return fmt.Errorf("while parsing simulation file '%s': %w", c.ConfigPaths.SimulationFilePath, err) + } } + if simCfg.Simulation == nil { simCfg.Simulation = new(bool) } + if c.Crowdsec != nil { c.Crowdsec.SimulationConfig = &simCfg } + if c.Cscli != nil { c.Cscli.SimulationConfig = &simCfg } + return nil } diff --git a/pkg/csconfig/simulation_test.go b/pkg/csconfig/simulation_test.go index 01f05e3975a..a1e5f0a5b02 100644 --- a/pkg/csconfig/simulation_test.go +++ b/pkg/csconfig/simulation_test.go @@ -60,7 +60,7 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: "while unmarshaling simulation file './testdata/config.yaml' : yaml: unmarshal errors", + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, { name: "basic bad file content", @@ -71,12 +71,11 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: "while unmarshaling simulation file './testdata/config.yaml' : yaml: unmarshal errors", + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { err := tc.input.LoadSimulation() cstest.RequireErrorContains(t, err, tc.expectedErr) @@ -124,7 +123,6 @@ func TestIsSimulated(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { isSimulated := tc.SimulationConfig.IsSimulated(tc.Input) require.Equal(t, tc.expected, isSimulated) diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index b5c86f224ab..e996fa9b68c 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -45,7 +45,7 @@ type PluginBroker struct { pluginConfigByName map[string]PluginConfig pluginMap map[string]plugin.Plugin notificationConfigsByPluginType map[string][][]byte // "slack" -> []{config1, config2} - notificationPluginByName map[string]Notifier + notificationPluginByName map[string]protobufs.NotifierServer watcher PluginWatcher pluginKillMethods []func() pluginProcConfig *csconfig.PluginCfg @@ -72,10 +72,10 @@ type ProfileAlert struct { Alert *models.Alert } -func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { +func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { pb.PluginChannel = make(chan ProfileAlert) pb.notificationConfigsByPluginType = make(map[string][][]byte) - pb.notificationPluginByName = make(map[string]Notifier) + pb.notificationPluginByName = make(map[string]protobufs.NotifierServer) pb.pluginMap = make(map[string]plugin.Plugin) pb.pluginConfigByName = make(map[string]PluginConfig) pb.alertsByPluginName = make(map[string][]*models.Alert) @@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs if err := pb.loadConfig(configPaths.NotificationDir); err != nil { return fmt.Errorf("while loading plugin config: %w", err) } - if err := pb.loadPlugins(configPaths.PluginDir); err != nil { + if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil { return fmt.Errorf("while loading plugin: %w", err) } pb.watcher = PluginWatcher{} @@ -103,7 +103,6 @@ func (pb *PluginBroker) Kill() { func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { //we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) pb.watcher.Start(&tomb.Tomb{}) -loop: for { select { case profileAlert := <-pb.PluginChannel: @@ -137,7 +136,7 @@ loop: case <-pb.watcher.tomb.Dead(): log.Info("killing all plugins") pb.Kill() - break loop + return case pluginName := <-pb.watcher.PluginEvents: // this can be run in goroutine, but then locks will be needed pluginMutex.Lock() @@ -231,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error { return nil } -func (pb *PluginBroker) loadPlugins(path string) error { +func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error { binaryPaths, err := listFilesAtPath(path) if err != nil { return err @@ -266,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return err } data = []byte(csstring.StrictExpand(string(data), os.LookupEnv)) - _, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data}) + _, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data}) if err != nil { return fmt.Errorf("while configuring %s: %w", pc.Name, err) } @@ -277,7 +276,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return pb.verifyPluginBinaryWithProfile() } -func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (Notifier, error) { +func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (protobufs.NotifierServer, error) { handshake, err := getHandshake() if err != nil { @@ -314,7 +313,7 @@ func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) ( return nil, err } pb.pluginKillMethods = append(pb.pluginKillMethods, c.Kill) - return raw.(Notifier), nil + return raw.(protobufs.NotifierServer), nil } func (pb *PluginBroker) pushNotificationsToPlugin(pluginName string, alerts []*models.Alert) error { diff --git a/pkg/csplugin/broker_suite_test.go b/pkg/csplugin/broker_suite_test.go index 778bb2dfe2e..1210c67058a 100644 --- a/pkg/csplugin/broker_suite_test.go +++ b/pkg/csplugin/broker_suite_test.go @@ -1,6 +1,7 @@ package csplugin import ( + "context" "io" "os" "os/exec" @@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() { func (s *PluginSuite) SetupSubTest() { var err error + t := s.T() s.runDir, err = os.MkdirTemp("", "cs_plugin_test") @@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() { func (s *PluginSuite) TearDownSubTest() { t := s.T() + if s.pluginBroker != nil { s.pluginBroker.Kill() s.pluginBroker = nil @@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() { os.Remove("./out") } -func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) { +func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) { pb := PluginBroker{} + if procCfg == nil { procCfg = &csconfig.PluginCfg{} } + profiles := csconfig.NewDefaultConfig().API.Server.Profiles profiles = append(profiles, &csconfig.ProfileCfg{ Notifications: []string{"dummy_default"}, }) - err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{ + + err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{ PluginDir: s.pluginDir, NotificationDir: s.notifDir, }) + s.pluginBroker = &pb + return s.pluginBroker, err } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index 9adb35ad7cc..ae5a615b489 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -14,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -38,7 +39,7 @@ func (s *PluginSuite) readconfig() PluginConfig { require.NoError(t, err, "unable to read config file %s", s.pluginConfig) err = yaml.Unmarshal(orig, &config) - require.NoError(t, err, "unable to unmarshal config file") + require.NoError(t, err, "unable to parse config file") return config } @@ -46,13 +47,14 @@ func (s *PluginSuite) readconfig() PluginConfig { func (s *PluginSuite) writeconfig(config PluginConfig) { t := s.T() data, err := yaml.Marshal(&config) - require.NoError(t, err, "unable to marshal config file") + require.NoError(t, err, "unable to serialize config file") - err = os.WriteFile(s.pluginConfig, data, 0644) + err = os.WriteFile(s.pluginConfig, data, 0o644) require.NoError(t, err, "unable to write config file %s", s.pluginConfig) } func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -129,26 +131,28 @@ func (s *PluginSuite) TestBrokerInit() { } for _, tc := range tests { - tc := tc s.Run(tc.name, func() { t := s.T() if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerNoThreshold() { + ctx := context.Background() + var alerts []models.Alert DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { + ctx := context.Background() + // test grouping by "time" DefaultEmptyTicker = 50 * time.Millisecond @@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { cfg.GroupWait = 4 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { } func (s *PluginSuite) TestBrokerRunGroupThreshold() { + ctx := context.Background() // test grouping by "size" DefaultEmptyTicker = 50 * time.Millisecond @@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { cfg.GroupThreshold = 4 s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { } func (s *PluginSuite) TestBrokerRunTimeThreshold() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { } func (s *PluginSuite) TestBrokerRunSimple() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 97a3ad33deb..570f23e5015 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions */ func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -54,22 +56,22 @@ func (s *PluginSuite) TestBrokerInit() { } for _, tc := range tests { - tc := tc s.Run(tc.name, func() { t := s.T() if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerRun() { + ctx := context.Background() t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/hclog_adapter.go b/pkg/csplugin/hclog_adapter.go index 9550e4b4539..44a22463709 100644 --- a/pkg/csplugin/hclog_adapter.go +++ b/pkg/csplugin/hclog_adapter.go @@ -221,14 +221,13 @@ func merge(dst map[string]interface{}, k, v interface{}) { func safeString(str fmt.Stringer) (s string) { defer func() { if panicVal := recover(); panicVal != nil { - if v := reflect.ValueOf(str); v.Kind() == reflect.Ptr && v.IsNil() { - s = "NULL" - } else { + if v := reflect.ValueOf(str); v.Kind() != reflect.Ptr || !v.IsNil() { panic(panicVal) } + s = "NULL" } }() s = str.String() - return + return //nolint:revive // bare return for the defer } diff --git a/pkg/csplugin/helpers.go b/pkg/csplugin/helpers.go index 75ee773b808..915f17e5dd3 100644 --- a/pkg/csplugin/helpers.go +++ b/pkg/csplugin/helpers.go @@ -5,9 +5,10 @@ import ( "os" "text/template" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" - log "github.com/sirupsen/logrus" ) var helpers = template.FuncMap{ diff --git a/pkg/csplugin/listfiles_test.go b/pkg/csplugin/listfiles_test.go index a7b41c51d07..c476d7a4e4a 100644 --- a/pkg/csplugin/listfiles_test.go +++ b/pkg/csplugin/listfiles_test.go @@ -21,7 +21,7 @@ func TestListFilesAtPath(t *testing.T) { require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "slack")) require.NoError(t, err) - err = os.Mkdir(filepath.Join(dir, "somedir"), 0755) + err = os.Mkdir(filepath.Join(dir, "somedir"), 0o755) require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "somedir", "inner")) require.NoError(t, err) @@ -47,7 +47,6 @@ func TestListFilesAtPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, err := listFilesAtPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/notifier.go b/pkg/csplugin/notifier.go index a4f5bbc0ed8..615322ac0c3 100644 --- a/pkg/csplugin/notifier.go +++ b/pkg/csplugin/notifier.go @@ -2,7 +2,7 @@ package csplugin import ( "context" - "fmt" + "errors" plugin "github.com/hashicorp/go-plugin" "google.golang.org/grpc" @@ -10,17 +10,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) -type Notifier interface { - Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) - Configure(ctx context.Context, cfg *protobufs.Config) (*protobufs.Empty, error) -} - type NotifierPlugin struct { plugin.Plugin - Impl Notifier + Impl protobufs.NotifierServer } -type GRPCClient struct{ client protobufs.NotifierClient } +type GRPCClient struct{ + protobufs.UnimplementedNotifierServer + client protobufs.NotifierClient +} func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { done := make(chan error) @@ -35,19 +33,17 @@ func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notific return &protobufs.Empty{}, err case <-ctx.Done(): - return &protobufs.Empty{}, fmt.Errorf("timeout exceeded") + return &protobufs.Empty{}, errors.New("timeout exceeded") } } func (m *GRPCClient) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { - _, err := m.client.Configure( - context.Background(), config, - ) + _, err := m.client.Configure(ctx, config) return &protobufs.Empty{}, err } type GRPCServer struct { - Impl Notifier + Impl protobufs.NotifierServer } func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { diff --git a/pkg/csplugin/utils.go b/pkg/csplugin/utils.go index 216a079d457..571d78add56 100644 --- a/pkg/csplugin/utils.go +++ b/pkg/csplugin/utils.go @@ -51,7 +51,7 @@ func getUID(username string) (uint32, error) { return 0, err } if uid < 0 || uid > math.MaxInt32 { - return 0, fmt.Errorf("out of bound uid") + return 0, errors.New("out of bound uid") } return uint32(uid), nil } @@ -66,7 +66,7 @@ func getGID(groupname string) (uint32, error) { return 0, err } if gid < 0 || gid > math.MaxInt32 { - return 0, fmt.Errorf("out of bound gid") + return 0, errors.New("out of bound gid") } return uint32(gid), nil } @@ -123,10 +123,10 @@ func pluginIsValid(path string) error { mode := details.Mode() perm := uint32(mode) - if (perm & 00002) != 0 { + if (perm & 0o0002) != 0 { return fmt.Errorf("plugin at %s is world writable, world writable plugins are invalid", path) } - if (perm & 00020) != 0 { + if (perm & 0o0020) != 0 { return fmt.Errorf("plugin at %s is group writable, group writable plugins are invalid", path) } if (mode & os.ModeSetgid) != 0 { diff --git a/pkg/csplugin/utils_test.go b/pkg/csplugin/utils_test.go index f02e7f491b2..7fa9a77acd5 100644 --- a/pkg/csplugin/utils_test.go +++ b/pkg/csplugin/utils_test.go @@ -37,7 +37,6 @@ func TestGetPluginNameAndTypeFromPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, got1, err := getPluginTypeAndSubtypeFromPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/utils_windows.go b/pkg/csplugin/utils_windows.go index dfb11aff548..91002079398 100644 --- a/pkg/csplugin/utils_windows.go +++ b/pkg/csplugin/utils_windows.go @@ -3,6 +3,7 @@ package csplugin import ( + "errors" "fmt" "os" "os/exec" @@ -77,14 +78,14 @@ func CheckPerms(path string) error { return fmt.Errorf("while getting owner security info: %w", err) } if !sd.IsValid() { - return fmt.Errorf("security descriptor is invalid") + return errors.New("security descriptor is invalid") } owner, _, err := sd.Owner() if err != nil { return fmt.Errorf("while getting owner: %w", err) } if !owner.IsValid() { - return fmt.Errorf("owner is invalid") + return errors.New("owner is invalid") } if !owner.Equals(systemSid) && !owner.Equals(currentUserSid) && !owner.Equals(adminSid) { @@ -100,10 +101,6 @@ func CheckPerms(path string) error { return fmt.Errorf("no DACL found on plugin, meaning fully permissive access on plugin %s", path) } - if err != nil { - return fmt.Errorf("while looking up current user sid: %w", err) - } - rs := reflect.ValueOf(dacl).Elem() /* @@ -119,7 +116,7 @@ func CheckPerms(path string) error { */ aceCount := rs.Field(3).Uint() - for i := uint64(0); i < aceCount; i++ { + for i := range aceCount { ace := &AccessAllowedAce{} ret, _, _ := procGetAce.Call(uintptr(unsafe.Pointer(dacl)), uintptr(i), uintptr(unsafe.Pointer(&ace))) if ret == 0 { diff --git a/pkg/csplugin/utils_windows_test.go b/pkg/csplugin/utils_windows_test.go index 6a76e1215e5..1eb4dfb9033 100644 --- a/pkg/csplugin/utils_windows_test.go +++ b/pkg/csplugin/utils_windows_test.go @@ -37,7 +37,6 @@ func TestGetPluginNameAndTypeFromPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, got1, err := getPluginTypeAndSubtypeFromPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index d0bb7b2f142..84e63ec6493 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -15,11 +15,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -var ctx = context.Background() - func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) { testTomb.Kill(nil) <-pw.PluginEvents + if err := testTomb.Wait(); err != nil { log.Fatal(err) } @@ -34,7 +33,7 @@ func resetWatcherAlertCounter(pw *PluginWatcher) { } func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) { - for i := 0; i < n; i++ { + for range n { pw.Inserts <- pluginName } } @@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error { case <-ctx.Done(): return ctx.Err() } + return nil } func TestPluginWatcherInterval(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) testTomb := tomb.Tomb{} @@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") resetTestTomb(&testTomb, &pw) @@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) @@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) { } func TestPluginAlertCountWatcher(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) configs := map[string]PluginConfig{ @@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) { }, } testTomb := tomb.Tomb{} + pw.Init(configs, alertsByPluginName) pw.Start(&testTomb) // Channel won't contain any events since threshold is not crossed. ct, cancel := context.WithTimeout(ctx, time.Second) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel won't contain any events since threshold is not crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 4, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel will contain an event since threshold is crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 5, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) diff --git a/pkg/csprofiles/csprofiles.go b/pkg/csprofiles/csprofiles.go index 95fbb356f3d..52cda1ed2e1 100644 --- a/pkg/csprofiles/csprofiles.go +++ b/pkg/csprofiles/csprofiles.go @@ -4,8 +4,8 @@ import ( "fmt" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -35,7 +35,7 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { xlog := log.New() if err := types.ConfigureLogger(xlog); err != nil { - log.Fatalf("While creating profiles-specific logger : %s", err) + return nil, fmt.Errorf("while configuring profiles-specific logger: %w", err) } xlog.SetLevel(log.InfoLevel) @@ -196,6 +196,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision decisions = append(decisions, subdecisions...) } else { Profile.Logger.Debugf("Profile %s filter is unsuccessful", Profile.Cfg.Name) + if Profile.Cfg.OnFailure == "break" { break } diff --git a/pkg/csprofiles/csprofiles_test.go b/pkg/csprofiles/csprofiles_test.go index be1d0178e72..0247243ddd3 100644 --- a/pkg/csprofiles/csprofiles_test.go +++ b/pkg/csprofiles/csprofiles_test.go @@ -102,7 +102,6 @@ func TestNewProfile(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { profilesCfg := []*csconfig.ProfileCfg{ test.profileCfg, @@ -196,7 +195,6 @@ func TestEvaluateProfile(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { profilesCfg := []*csconfig.ProfileCfg{ tt.args.profileCfg, diff --git a/pkg/cticlient/client.go b/pkg/cticlient/client.go index 4df4d65a63c..90112d80abf 100644 --- a/pkg/cticlient/client.go +++ b/pkg/cticlient/client.go @@ -9,6 +9,8 @@ import ( "strings" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) const ( @@ -43,7 +45,10 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map if err != nil { return nil, err } - req.Header.Set("x-api-key", c.apiKey) + + req.Header.Set("X-Api-Key", c.apiKey) + req.Header.Set("User-Agent", useragent.Default()) + resp, err := c.httpClient.Do(req) if err != nil { return nil, err diff --git a/pkg/cticlient/client_test.go b/pkg/cticlient/client_test.go index 79406a6c2a9..cdbbd0c9732 100644 --- a/pkg/cticlient/client_test.go +++ b/pkg/cticlient/client_test.go @@ -38,7 +38,7 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { func fireHandler(req *http.Request) *http.Response { var err error - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { log.Warningf("invalid api key: %s", apiKey) @@ -105,7 +105,7 @@ func fireHandler(req *http.Request) *http.Response { } func smokeHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -137,7 +137,7 @@ func smokeHandler(req *http.Request) *http.Response { } func rateLimitedHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -154,7 +154,7 @@ func rateLimitedHandler(req *http.Request) *http.Response { } func searchHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, diff --git a/pkg/cwhub/cwhub.go b/pkg/cwhub/cwhub.go index 9ce091fad39..683f1853b43 100644 --- a/pkg/cwhub/cwhub.go +++ b/pkg/cwhub/cwhub.go @@ -4,13 +4,26 @@ import ( "fmt" "net/http" "path/filepath" - "sort" "strings" "time" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) +// hubTransport wraps a Transport to set a custom User-Agent. +type hubTransport struct { + http.RoundTripper +} + +func (t *hubTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", useragent.Default()) + return t.RoundTripper.RoundTrip(req) +} + +// hubClient is the HTTP client used to communicate with the CrowdSec Hub. var hubClient = &http.Client{ - Timeout: 120 * time.Second, + Timeout: 120 * time.Second, + Transport: &hubTransport{http.DefaultTransport}, } // safePath returns a joined path and ensures that it does not escape the base directory. @@ -31,10 +44,3 @@ func safePath(dir, filePath string) (string, error) { return absFilePath, nil } - -// SortItemSlice sorts a slice of items by name, case insensitive. -func SortItemSlice(items []*Item) { - sort.Slice(items, func(i, j int) bool { - return strings.ToLower(items[i].Name) < strings.ToLower(items[j].Name) - }) -} diff --git a/pkg/cwhub/cwhub_test.go b/pkg/cwhub/cwhub_test.go index 0a1363ebe09..17e7a0dc723 100644 --- a/pkg/cwhub/cwhub_test.go +++ b/pkg/cwhub/cwhub_test.go @@ -1,6 +1,8 @@ package cwhub import ( + "context" + "fmt" "io" "net/http" "os" @@ -14,7 +16,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -const mockURLTemplate = "https://hub-cdn.crowdsec.net/%s/%s" +const mockURLTemplate = "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s" /* To test : @@ -61,7 +63,16 @@ func testHub(t *testing.T, update bool) *Hub { IndexPath: ".index.json", } - hub, err := NewHub(local, remote, update, log.StandardLogger()) + hub, err := NewHub(local, remote, log.StandardLogger()) + require.NoError(t, err) + + if update { + ctx := context.Background() + err := hub.Update(ctx) + require.NoError(t, err) + } + + err = hub.Load() require.NoError(t, err) return hub @@ -107,7 +118,7 @@ func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { // FAKE PARSER resp, ok := responseByPath[req.URL.Path] if !ok { - log.Fatalf("unexpected url :/ %s", req.URL.Path) + return nil, fmt.Errorf("unexpected url: %s", req.URL.Path) } response.Body = io.NopCloser(strings.NewReader(resp)) @@ -132,18 +143,18 @@ func fileToStringX(path string) string { func setResponseByPath() { responseByPath = map[string]string{ - "/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), - "/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), - "/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./testdata/collection_v1.yaml"), - "/master/.index.json": fileToStringX("./testdata/index1.json"), - "/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true + "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), + "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), + "/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./testdata/collection_v1.yaml"), + "/crowdsecurity/master/.index.json": fileToStringX("./testdata/index1.json"), + "/crowdsecurity/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true name: crowdsecurity/foobar_scenario`, - "/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true + "/crowdsecurity/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true name: crowdsecurity/foobar_scenario`, - "/master/collections/crowdsecurity/foobar_subcollection.yaml": ` + "/crowdsecurity/master/collections/crowdsecurity/foobar_subcollection.yaml": ` blah: blalala qwe: jejwejejw`, - "/master/collections/crowdsecurity/foobar.yaml": ` + "/crowdsecurity/master/collections/crowdsecurity/foobar.yaml": ` blah: blalala qwe: jejwejejw`, } diff --git a/pkg/cwhub/dataset.go b/pkg/cwhub/dataset.go index c900752b8b3..90bc9e057f9 100644 --- a/pkg/cwhub/dataset.go +++ b/pkg/cwhub/dataset.go @@ -1,16 +1,17 @@ package cwhub import ( + "context" "errors" "fmt" "io" - "net/http" - "os" "time" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/go-cs-lib/downloader" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,97 +20,8 @@ type DataSet struct { Data []types.DataSource `yaml:"data,omitempty"` } -// downloadFile downloads a file and writes it to disk, with no hash verification. -func downloadFile(url string, destPath string) error { - resp, err := hubClient.Get(url) - if err != nil { - return fmt.Errorf("while downloading %s: %w", url, err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) - } - - file, err := os.Create(destPath) - if err != nil { - return err - } - defer file.Close() - - // avoid reading the whole file in memory - _, err = io.Copy(file, resp.Body) - if err != nil { - return err - } - - if err = file.Sync(); err != nil { - return err - } - - return nil -} - -// needsUpdate checks if a data file has to be downloaded (or updated). -// if the local file doesn't exist, update. -// if the remote is newer than the local file, update. -// if the remote has no modification date, but local file has been modified > a week ago, update. -func needsUpdate(destPath string, url string, logger *logrus.Logger) bool { - fileInfo, err := os.Stat(destPath) - - switch { - case os.IsNotExist(err): - return true - case err != nil: - logger.Errorf("while getting %s: %s", destPath, err) - return true - } - - resp, err := hubClient.Head(url) - if err != nil { - logger.Errorf("while getting %s: %s", url, err) - // Head failed, Get would likely fail too -> no update - return false - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - logger.Errorf("bad http code %d for %s", resp.StatusCode, url) - return false - } - - // update if local file is older than this - shelfLife := 7 * 24 * time.Hour - - lastModify := fileInfo.ModTime() - - localIsOld := lastModify.Add(shelfLife).Before(time.Now()) - - remoteLastModified := resp.Header.Get("Last-Modified") - if remoteLastModified == "" { - if localIsOld { - logger.Infof("no last modified date for %s, but local file is older than %s", url, shelfLife) - } - - return localIsOld - } - - lastAvailable, err := time.Parse(time.RFC1123, remoteLastModified) - if err != nil { - logger.Warningf("while parsing last modified date for %s: %s", url, err) - return localIsOld - } - - if lastModify.Before(lastAvailable) { - logger.Infof("new version available, updating %s", destPath) - return true - } - - return false -} - // downloadDataSet downloads all the data files for an item. -func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error { +func downloadDataSet(ctx context.Context, dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error { dec := yaml.NewDecoder(reader) for { @@ -129,12 +41,29 @@ func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *lo return err } - if force || needsUpdate(destPath, dataS.SourceURL, logger) { - logger.Debugf("downloading %s in %s", dataS.SourceURL, destPath) + d := downloader. + New(). + WithHTTPClient(hubClient). + ToFile(destPath). + CompareContent(). + WithLogger(logrus.WithField("url", dataS.SourceURL)) + + if !force { + d = d.WithLastModified(). + WithShelfLife(7 * 24 * time.Hour) + } + + downloaded, err := d.Download(ctx, dataS.SourceURL) + if err != nil { + return fmt.Errorf("while getting data: %w", err) + } - if err := downloadFile(dataS.SourceURL, destPath); err != nil { - return fmt.Errorf("while getting data: %w", err) - } + if downloaded { + logger.Infof("Downloaded %s", destPath) + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("updated %s\n", destPath) } } } diff --git a/pkg/cwhub/dataset_test.go b/pkg/cwhub/dataset_test.go deleted file mode 100644 index f23f4878285..00000000000 --- a/pkg/cwhub/dataset_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package cwhub - -import ( - "os" - "testing" - - "github.com/jarcoal/httpmock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDownloadFile(t *testing.T) { - examplePath := "./example.txt" - defer os.Remove(examplePath) - - httpmock.Activate() - defer httpmock.DeactivateAndReset() - - //OK - httpmock.RegisterResponder( - "GET", - "https://example.com/xx", - httpmock.NewStringResponder(200, "example content oneoneone"), - ) - - httpmock.RegisterResponder( - "GET", - "https://example.com/x", - httpmock.NewStringResponder(404, "not found"), - ) - - err := downloadFile("https://example.com/xx", examplePath) - require.NoError(t, err) - - content, err := os.ReadFile(examplePath) - assert.Equal(t, "example content oneoneone", string(content)) - require.NoError(t, err) - - //bad uri - err = downloadFile("https://zz.com", examplePath) - require.Error(t, err) - - //404 - err = downloadFile("https://example.com/x", examplePath) - require.Error(t, err) - - //bad target - err = downloadFile("https://example.com/xx", "") - require.Error(t, err) -} diff --git a/pkg/cwhub/doc.go b/pkg/cwhub/doc.go index 85767265048..f86b95c6454 100644 --- a/pkg/cwhub/doc.go +++ b/pkg/cwhub/doc.go @@ -2,10 +2,10 @@ // // # Definitions // -// - A hub ITEM is a file that defines a parser, a scenario, a collection... in the case of a collection, it has dependencies on other hub items. -// - The hub INDEX is a JSON file that contains a tree of available hub items. -// - A REMOTE HUB is an HTTP server that hosts the hub index and the hub items. It can serve from several branches, usually linked to the CrowdSec version. -// - A LOCAL HUB is a directory that contains a copy of the hub index and the downloaded hub items. +// - A hub ITEM is a file that defines a parser, a scenario, a collection... in the case of a collection, it has dependencies on other hub items. +// - The hub INDEX is a JSON file that contains a tree of available hub items. +// - A REMOTE HUB is an HTTP server that hosts the hub index and the hub items. It can serve from several branches, usually linked to the CrowdSec version. +// - A LOCAL HUB is a directory that contains a copy of the hub index and the downloaded hub items. // // Once downloaded, hub items can be installed by linking to them from the configuration directory. // If an item is present in the configuration directory but it's not a link to the local hub, it is @@ -17,15 +17,15 @@ // // For the local hub (HubDir = /etc/crowdsec/hub): // -// - /etc/crowdsec/hub/.index.json -// - /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml -// - /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml +// - /etc/crowdsec/hub/.index.json +// - /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml +// - /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml // // For the configuration directory (InstallDir = /etc/crowdsec): // -// - /etc/crowdsec/parsers/{stage}/{parser-name.yaml} -> /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml -// - /etc/crowdsec/scenarios/{scenario-name.yaml} -> /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml -// - /etc/crowdsec/scenarios/local-scenario.yaml +// - /etc/crowdsec/parsers/{stage}/{parser-name.yaml} -> /etc/crowdsec/hub/parsers/{stage}/{author}/{parser-name}.yaml +// - /etc/crowdsec/scenarios/{scenario-name.yaml} -> /etc/crowdsec/hub/scenarios/{author}/{scenario-name}.yaml +// - /etc/crowdsec/scenarios/local-scenario.yaml // // Note that installed items are not grouped by author, this may change in the future if we want to // support items with the same name from different authors. @@ -35,11 +35,10 @@ // Additionally, an item can reference a DATA SET that is installed in a different location than // the item itself. These files are stored in the data directory (InstallDataDir = /var/lib/crowdsec/data). // -// - /var/lib/crowdsec/data/http_path_traversal.txt -// - /var/lib/crowdsec/data/jira_cve_2021-26086.txt -// - /var/lib/crowdsec/data/log4j2_cve_2021_44228.txt -// - /var/lib/crowdsec/data/sensitive_data.txt -// +// - /var/lib/crowdsec/data/http_path_traversal.txt +// - /var/lib/crowdsec/data/jira_cve_2021-26086.txt +// - /var/lib/crowdsec/data/log4j2_cve_2021_44228.txt +// - /var/lib/crowdsec/data/sensitive_data.txt // // # Using the package // @@ -58,15 +57,24 @@ // InstallDir: "/etc/crowdsec", // InstallDataDir: "/var/lib/crowdsec/data", // } -// hub, err := cwhub.NewHub(localHub, nil, false) +// +// hub, err := cwhub.NewHub(localHub, nil, logger) // if err != nil { // return fmt.Errorf("unable to initialize hub: %w", err) // } // -// Now you can use the hub to access the existing items: +// If the logger is nil, the item-by-item messages will be discarded, including warnings. +// After configuring the hub, you must sync its state with items on disk. +// +// err := hub.Load() +// if err != nil { +// return fmt.Errorf("unable to load hub: %w", err) +// } +// +// Now you can use the hub object to access the existing items: // // // list all the parsers -// for _, parser := range hub.GetItemMap(cwhub.PARSERS) { +// for _, parser := range hub.GetItemsByType(cwhub.PARSERS, false) { // fmt.Printf("parser: %s\n", parser.Name) // } // @@ -78,13 +86,13 @@ // // You can also install items if they have already been downloaded: // -// // install a parser -// force := false -// downloadOnly := false -// err := parser.Install(force, downloadOnly) -// if err != nil { -// return fmt.Errorf("unable to install parser: %w", err) -// } +// // install a parser +// force := false +// downloadOnly := false +// err := parser.Install(force, downloadOnly) +// if err != nil { +// return fmt.Errorf("unable to install parser: %w", err) +// } // // As soon as you try to install an item that is not downloaded or is not up-to-date (meaning its computed hash // does not correspond to the latest version available in the index), a download will be attempted and you'll @@ -92,13 +100,13 @@ // // To provide the remote hub configuration, use the second parameter of NewHub(): // -// remoteHub := cwhub.RemoteHubCfg{ -// URLTemplate: "https://hub-cdn.crowdsec.net/%s/%s", +// remoteHub := cwhub.RemoteHubCfg{ +// URLTemplate: "https://cdn-hub.crowdsec.net/crowdsecurity/%s/%s", // Branch: "master", // IndexPath: ".index.json", // } -// updateIndex := false -// hub, err := cwhub.NewHub(localHub, remoteHub, updateIndex) +// +// hub, err := cwhub.NewHub(localHub, remoteHub, logger) // if err != nil { // return fmt.Errorf("unable to initialize hub: %w", err) // } @@ -106,8 +114,13 @@ // The URLTemplate is a string that will be used to build the URL of the remote hub. It must contain two // placeholders: the branch and the file path (it will be an index or an item). // -// Setting the third parameter to true will download the latest version of the index, if available on the -// specified branch. -// There is no exported method to update the index once the hub struct is created. +// Before calling hub.Load(), you can update the index file by calling the Update() method: +// +// err := hub.Update(context.Background()) +// if err != nil { +// return fmt.Errorf("unable to update hub index: %w", err) +// } // +// Note that the command will fail if the hub has already been synced. If you want to do it (ex. after a configuration +// change the application is notified with SIGHUP) you have to instantiate a new hub object and dispose of the old one. package cwhub diff --git a/pkg/cwhub/errors.go b/pkg/cwhub/errors.go index 789c2eced7b..b0be444fcba 100644 --- a/pkg/cwhub/errors.go +++ b/pkg/cwhub/errors.go @@ -5,10 +5,8 @@ import ( "fmt" ) -var ( - // ErrNilRemoteHub is returned when the remote hub configuration is not provided to the NewHub constructor. - ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") -) +// ErrNilRemoteHub is returned when trying to download with a local-only configuration. +var ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") // IndexNotFoundError is returned when the remote hub index is not found. type IndexNotFoundError struct { diff --git a/pkg/cwhub/hub.go b/pkg/cwhub/hub.go index 21a19bc4526..f74a794a512 100644 --- a/pkg/cwhub/hub.go +++ b/pkg/cwhub/hub.go @@ -1,27 +1,30 @@ package cwhub import ( - "bytes" + "context" "encoding/json" + "errors" "fmt" "io" "os" "path" - "slices" "strings" "github.com/sirupsen/logrus" + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) // Hub is the main structure for the package. type Hub struct { - items HubItems // Items read from HubDir and InstallDir - local *csconfig.LocalHubCfg - remote *RemoteHubCfg - Warnings []string // Warnings encountered during sync - logger *logrus.Logger + items HubItems // Items read from HubDir and InstallDir + pathIndex map[string]*Item + local *csconfig.LocalHubCfg + remote *RemoteHubCfg + logger *logrus.Logger + Warnings []string // Warnings encountered during sync } // GetDataDir returns the data directory, where data sets are installed. @@ -29,12 +32,13 @@ func (h *Hub) GetDataDir() string { return h.local.InstallDataDir } -// NewHub returns a new Hub instance with local and (optionally) remote configuration, and syncs the local state. -// If updateIndex is true, the local index file is updated from the remote before reading the state of the items. +// NewHub returns a new Hub instance with local and (optionally) remote configuration. +// The hub is not synced automatically. Load() must be called to read the index, sync the local state, +// and check for unmanaged items. // All download operations (including updateIndex) return ErrNilRemoteHub if the remote configuration is not set. -func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, updateIndex bool, logger *logrus.Logger) (*Hub, error) { +func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, logger *logrus.Logger) (*Hub, error) { if local == nil { - return nil, fmt.Errorf("no hub configuration found") + return nil, errors.New("no hub configuration found") } if logger == nil { @@ -43,28 +47,28 @@ func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, updateIndex bool, } hub := &Hub{ - local: local, - remote: remote, - logger: logger, + local: local, + remote: remote, + logger: logger, + pathIndex: make(map[string]*Item, 0), } - if updateIndex { - if err := hub.updateIndex(); err != nil { - return nil, err - } - } + return hub, nil +} - logger.Debugf("loading hub idx %s", local.HubIndexFile) +// Load reads the state of the items on disk. +func (h *Hub) Load() error { + h.logger.Debugf("loading hub idx %s", h.local.HubIndexFile) - if err := hub.parseIndex(); err != nil { - return nil, fmt.Errorf("failed to load index: %w", err) + if err := h.parseIndex(); err != nil { + return fmt.Errorf("failed to load hub index: %w", err) } - if err := hub.localSync(); err != nil { - return nil, fmt.Errorf("failed to sync items: %w", err) + if err := h.localSync(); err != nil { + return fmt.Errorf("failed to sync hub items: %w", err) } - return hub, nil + return nil } // parseIndex takes the content of an index file and fills the map of associated parsers/scenarios/collections. @@ -75,7 +79,7 @@ func (h *Hub) parseIndex() error { } if err := json.Unmarshal(bidx, &h.items); err != nil { - return fmt.Errorf("failed to unmarshal index: %w", err) + return fmt.Errorf("failed to parse index: %w", err) } h.logger.Debugf("%d item types in hub index", len(ItemTypes)) @@ -114,13 +118,14 @@ func (h *Hub) ItemStats() []string { tainted := 0 for _, itemType := range ItemTypes { - if len(h.GetItemMap(itemType)) == 0 { + items := h.GetItemsByType(itemType, false) + if len(items) == 0 { continue } - loaded += fmt.Sprintf("%d %s, ", len(h.GetItemMap(itemType)), itemType) + loaded += fmt.Sprintf("%d %s, ", len(items), itemType) - for _, item := range h.GetItemMap(itemType) { + for _, item := range items { if item.State.IsLocal() { local++ } @@ -137,7 +142,7 @@ func (h *Hub) ItemStats() []string { } ret := []string{ - fmt.Sprintf("Loaded: %s", loaded), + "Loaded: " + loaded, } if local > 0 || tainted > 0 { @@ -147,29 +152,25 @@ func (h *Hub) ItemStats() []string { return ret } -// updateIndex downloads the latest version of the index and writes it to disk if it changed. -func (h *Hub) updateIndex() error { - body, err := h.remote.fetchIndex() - if err != nil { - return err +// Update downloads the latest version of the index and writes it to disk if it changed. It cannot be called after Load() +// unless the hub is completely empty. +func (h *Hub) Update(ctx context.Context) error { + if len(h.pathIndex) > 0 { + // if this happens, it's a bug. + return errors.New("cannot update hub after items have been loaded") } - oldContent, err := os.ReadFile(h.local.HubIndexFile) + downloaded, err := h.remote.fetchIndex(ctx, h.local.HubIndexFile) if err != nil { - if !os.IsNotExist(err) { - h.logger.Warningf("failed to read hub index: %s", err) - } - } else if bytes.Equal(body, oldContent) { - h.logger.Info("hub index is up to date") - return nil + return err } - if err = os.WriteFile(h.local.HubIndexFile, body, 0o644); err != nil { - return fmt.Errorf("failed to write hub index: %w", err) + if downloaded { + h.logger.Infof("Wrote index to %s", h.local.HubIndexFile) + } else { + h.logger.Info("hub index is up to date") } - h.logger.Infof("Wrote index to %s, %d bytes", h.local.HubIndexFile, len(body)) - return nil } @@ -179,6 +180,7 @@ func (h *Hub) addItem(item *Item) { } h.items[item.Type][item.Name] = item + h.pathIndex[item.State.LocalPath] = item } // GetItemMap returns the map of items for a given type. @@ -191,6 +193,11 @@ func (h *Hub) GetItem(itemType string, itemName string) *Item { return h.GetItemMap(itemType)[itemName] } +// GetItemByPath returns an item from hub based on its (absolute) local path. +func (h *Hub) GetItemByPath(itemPath string) *Item { + return h.pathIndex[itemPath] +} + // GetItemFQ returns an item from hub based on its type and name (type:author/name). func (h *Hub) GetItemFQ(itemFQName string) (*Item, error) { // type and name are separated by a colon @@ -213,73 +220,62 @@ func (h *Hub) GetItemFQ(itemFQName string) (*Item, error) { return i, nil } -// GetItemNames returns a slice of (full) item names for a given type -// (eg. for collections: crowdsecurity/apache2 crowdsecurity/nginx). -func (h *Hub) GetItemNames(itemType string) []string { - m := h.GetItemMap(itemType) - if m == nil { - return nil - } +// GetItemsByType returns a slice of all the items of a given type, installed or not, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetItemsByType(itemType string, sorted bool) []*Item { + items := h.items[itemType] - names := make([]string, 0, len(m)) - for k := range m { - names = append(names, k) - } + ret := make([]*Item, len(items)) - return names -} + if sorted { + for idx, name := range maptools.SortedKeysNoCase(items) { + ret[idx] = items[name] + } -// GetAllItems returns a slice of all the items of a given type, installed or not. -func (h *Hub) GetAllItems(itemType string) ([]*Item, error) { - if !slices.Contains(ItemTypes, itemType) { - return nil, fmt.Errorf("invalid item type %s", itemType) + return ret } - items := h.items[itemType] - - ret := make([]*Item, len(items)) - idx := 0 - for _, item := range items { ret[idx] = item - idx++ + idx += 1 } - return ret, nil + return ret } -// GetInstalledItems returns a slice of the installed items of a given type. -func (h *Hub) GetInstalledItems(itemType string) ([]*Item, error) { - if !slices.Contains(ItemTypes, itemType) { - return nil, fmt.Errorf("invalid item type %s", itemType) - } - - items := h.items[itemType] - - retItems := make([]*Item, 0) +// GetInstalledByType returns a slice of all the installed items of a given type, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetInstalledByType(itemType string, sorted bool) []*Item { + ret := make([]*Item, 0) - for _, item := range items { + for _, item := range h.GetItemsByType(itemType, sorted) { if item.State.Installed { - retItems = append(retItems, item) + ret = append(ret, item) } } - return retItems, nil + return ret } -// GetInstalledItemNames returns the names of the installed items of a given type. -func (h *Hub) GetInstalledItemNames(itemType string) ([]string, error) { - items, err := h.GetInstalledItems(itemType) - if err != nil { - return nil, err - } +// GetInstalledListForAPI returns a slice of names of all the installed scenarios and appsec-rules. +// The returned list is sorted by type (scenarios first) and case-insensitive name. +func (h *Hub) GetInstalledListForAPI() []string { + scenarios := h.GetInstalledByType(SCENARIOS, true) + appsecRules := h.GetInstalledByType(APPSEC_RULES, true) - retStr := make([]string, len(items)) + ret := make([]string, len(scenarios)+len(appsecRules)) - for idx, it := range items { - retStr[idx] = it.Name + idx := 0 + for _, item := range scenarios { + ret[idx] = item.Name + idx += 1 } - return retStr, nil + for _, item := range appsecRules { + ret[idx] = item.Name + idx += 1 + } + + return ret } diff --git a/pkg/cwhub/hub_test.go b/pkg/cwhub/hub_test.go index 86569cde324..1c2c9ccceca 100644 --- a/pkg/cwhub/hub_test.go +++ b/pkg/cwhub/hub_test.go @@ -1,6 +1,7 @@ package cwhub import ( + "context" "fmt" "os" "testing" @@ -18,7 +19,15 @@ func TestInitHubUpdate(t *testing.T) { IndexPath: ".index.json", } - _, err := NewHub(hub.local, remote, true, nil) + _, err := NewHub(hub.local, remote, nil) + require.NoError(t, err) + + ctx := context.Background() + + err = hub.Update(ctx) + require.NoError(t, err) + + err = hub.Load() require.NoError(t, err) } @@ -29,6 +38,10 @@ func TestUpdateIndex(t *testing.T) { tmpIndex, err := os.CreateTemp("", "index.json") require.NoError(t, err) + // close the file to avoid preventing the rename on windows + err = tmpIndex.Close() + require.NoError(t, err) + t.Cleanup(func() { os.Remove(tmpIndex.Name()) }) @@ -43,19 +56,21 @@ func TestUpdateIndex(t *testing.T) { hub.local.HubIndexFile = tmpIndex.Name() - err = hub.updateIndex() + ctx := context.Background() + + err = hub.Update(ctx) cstest.RequireErrorContains(t, err, "failed to build hub index request: invalid URL template 'x'") // bad domain fmt.Println("Test 'bad domain'") hub.remote = &RemoteHubCfg{ - URLTemplate: "https://baddomain/%s/%s", + URLTemplate: "https://baddomain/crowdsecurity/%s/%s", Branch: "master", IndexPath: ".index.json", } - err = hub.updateIndex() + err = hub.Update(ctx) require.NoError(t, err) // XXX: this is not failing // cstest.RequireErrorContains(t, err, "failed http request for hub index: Get") @@ -71,6 +86,6 @@ func TestUpdateIndex(t *testing.T) { hub.local.HubIndexFile = "/does/not/exist/index.json" - err = hub.updateIndex() - cstest.RequireErrorContains(t, err, "failed to write hub index: open /does/not/exist/index.json:") + err = hub.Update(ctx) + cstest.RequireErrorContains(t, err, "failed to create temporary download file for /does/not/exist/index.json:") } diff --git a/pkg/cwhub/item.go b/pkg/cwhub/item.go index 6c7da06c313..32d1acf94ff 100644 --- a/pkg/cwhub/item.go +++ b/pkg/cwhub/item.go @@ -7,7 +7,8 @@ import ( "slices" "github.com/Masterminds/semver/v3" - "github.com/enescakir/emoji" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" ) const ( @@ -28,10 +29,8 @@ const ( versionFuture // local version is higher latest, but is included in the index: should not happen ) -var ( - // The order is important, as it is used to range over sub-items in collections. - ItemTypes = []string{PARSERS, POSTOVERFLOWS, SCENARIOS, CONTEXTS, APPSEC_CONFIGS, APPSEC_RULES, COLLECTIONS} -) +// The order is important, as it is used to range over sub-items in collections. +var ItemTypes = []string{PARSERS, POSTOVERFLOWS, SCENARIOS, CONTEXTS, APPSEC_CONFIGS, APPSEC_RULES, COLLECTIONS} type HubItems map[string]map[string]*Item @@ -84,7 +83,7 @@ func (s *ItemState) Text() string { } // Emoji returns the status of the item as an emoji (eg. emoji.Warning). -func (s *ItemState) Emoji() emoji.Emoji { +func (s *ItemState) Emoji() string { switch { case s.IsLocal(): return emoji.House @@ -110,6 +109,7 @@ type Item struct { Name string `json:"name,omitempty" yaml:"name,omitempty"` // usually "author/name" FileName string `json:"file_name,omitempty" yaml:"file_name,omitempty"` // eg. apache2-logs.yaml Description string `json:"description,omitempty" yaml:"description,omitempty"` + Content string `json:"content,omitempty" yaml:"-"` Author string `json:"author,omitempty" yaml:"author,omitempty"` References []string `json:"references,omitempty" yaml:"references,omitempty"` diff --git a/pkg/cwhub/iteminstall.go b/pkg/cwhub/iteminstall.go index ceae3649118..912897d0d7e 100644 --- a/pkg/cwhub/iteminstall.go +++ b/pkg/cwhub/iteminstall.go @@ -1,6 +1,7 @@ package cwhub import ( + "context" "fmt" ) @@ -8,11 +9,11 @@ import ( func (i *Item) enable() error { if i.State.Installed { if i.State.Tainted { - return fmt.Errorf("%s is tainted, won't enable unless --force", i.Name) + return fmt.Errorf("%s is tainted, won't overwrite unless --force", i.Name) } if i.State.IsLocal() { - return fmt.Errorf("%s is local, won't enable", i.Name) + return fmt.Errorf("%s is local, won't overwrite", i.Name) } // if it's a collection, check sub-items even if the collection file itself is up-to-date @@ -39,7 +40,7 @@ func (i *Item) enable() error { } // Install installs the item from the hub, downloading it if needed. -func (i *Item) Install(force bool, downloadOnly bool) error { +func (i *Item) Install(ctx context.Context, force bool, downloadOnly bool) error { if downloadOnly && i.State.Downloaded && i.State.UpToDate { i.hub.logger.Infof("%s is already downloaded and up-to-date", i.Name) @@ -48,13 +49,12 @@ func (i *Item) Install(force bool, downloadOnly bool) error { } } - filePath, err := i.downloadLatest(force, true) + downloaded, err := i.downloadLatest(ctx, force, true) if err != nil { return err } - if downloadOnly { - i.hub.logger.Infof("Downloaded %s to %s", i.Name, filePath) + if downloadOnly && downloaded { return nil } @@ -62,6 +62,11 @@ func (i *Item) Install(force bool, downloadOnly bool) error { return fmt.Errorf("while enabling %s: %w", i.Name, err) } + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("installed %s\n", i.Name) + i.hub.logger.Infof("Enabled %s", i.Name) return nil diff --git a/pkg/cwhub/iteminstall_test.go b/pkg/cwhub/iteminstall_test.go index 80a419ec5da..5bfc7e8148e 100644 --- a/pkg/cwhub/iteminstall_test.go +++ b/pkg/cwhub/iteminstall_test.go @@ -1,6 +1,7 @@ package cwhub import ( + "context" "os" "testing" @@ -9,8 +10,10 @@ import ( ) func testInstall(hub *Hub, t *testing.T, item *Item) { + ctx := context.Background() + // Install the parser - _, err := item.downloadLatest(false, false) + _, err := item.downloadLatest(ctx, false, false) require.NoError(t, err, "failed to download %s", item.Name) err = hub.localSync() @@ -35,7 +38,8 @@ func testTaint(hub *Hub, t *testing.T, item *Item) { // truncate the file f, err := os.Create(item.State.LocalPath) require.NoError(t, err) - f.Close() + err = f.Close() + require.NoError(t, err) // Local sync and check status err = hub.localSync() @@ -47,8 +51,10 @@ func testTaint(hub *Hub, t *testing.T, item *Item) { func testUpdate(hub *Hub, t *testing.T, item *Item) { assert.False(t, item.State.UpToDate, "%s should not be up-to-date", item.Name) + ctx := context.Background() + // Update it + check status - _, err := item.downloadLatest(true, true) + _, err := item.downloadLatest(ctx, true, true) require.NoError(t, err, "failed to update %s", item.Name) // Local sync and check status diff --git a/pkg/cwhub/itemupgrade.go b/pkg/cwhub/itemupgrade.go index ac3b94f9836..105e5ebec31 100644 --- a/pkg/cwhub/itemupgrade.go +++ b/pkg/cwhub/itemupgrade.go @@ -3,23 +3,24 @@ package cwhub // Install, upgrade and remove items from the hub to the local configuration import ( - "bytes" - "crypto/sha256" + "context" + "crypto" + "encoding/base64" "encoding/hex" "errors" "fmt" - "io" - "net/http" "os" "path/filepath" - "github.com/enescakir/emoji" + "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/downloader" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" ) // Upgrade downloads and applies the last version of the item from the hub. -func (i *Item) Upgrade(force bool) (bool, error) { - updated := false - +func (i *Item) Upgrade(ctx context.Context, force bool) (bool, error) { if i.State.IsLocal() { i.hub.logger.Infof("not upgrading %s: local item", i.Name) return false, nil @@ -36,7 +37,7 @@ func (i *Item) Upgrade(force bool) (bool, error) { if i.State.UpToDate { i.hub.logger.Infof("%s: up-to-date", i.Name) - if err := i.DownloadDataIfNeeded(force); err != nil { + if err := i.DownloadDataIfNeeded(ctx, force); err != nil { return false, fmt.Errorf("%s: download failed: %w", i.Name, err) } @@ -46,7 +47,7 @@ func (i *Item) Upgrade(force bool) (bool, error) { } } - if _, err := i.downloadLatest(force, true); err != nil { + if _, err := i.downloadLatest(ctx, force, true); err != nil { return false, fmt.Errorf("%s: download failed: %w", i.Name, err) } @@ -54,20 +55,21 @@ func (i *Item) Upgrade(force bool) (bool, error) { if i.State.Tainted { i.hub.logger.Warningf("%v %s is tainted, --force to overwrite", emoji.Warning, i.Name) } - } else { - // a check on stdout is used while scripting to know if the hub has been upgraded - // and a configuration reload is required - // TODO: use a better way to communicate this - fmt.Printf("updated %s\n", i.Name) - i.hub.logger.Infof("%v %s: updated", emoji.Package, i.Name) - updated = true + + return false, nil } - return updated, nil + // a check on stdout is used while scripting to know if the hub has been upgraded + // and a configuration reload is required + // TODO: use a better way to communicate this + fmt.Printf("updated %s\n", i.Name) + i.hub.logger.Infof("%v %s: updated", emoji.Package, i.Name) + + return true, nil } // downloadLatest downloads the latest version of the item to the hub directory. -func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { +func (i *Item) downloadLatest(ctx context.Context, overwrite bool, updateOnly bool) (bool, error) { i.hub.logger.Debugf("Downloading %s %s", i.Type, i.Name) for _, sub := range i.SubItems() { @@ -82,99 +84,118 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { if sub.HasSubItems() { i.hub.logger.Tracef("collection, recurse") - if _, err := sub.downloadLatest(overwrite, updateOnly); err != nil { - return "", err + if _, err := sub.downloadLatest(ctx, overwrite, updateOnly); err != nil { + return false, err } } downloaded := sub.State.Downloaded - if _, err := sub.download(overwrite); err != nil { - return "", err + if _, err := sub.download(ctx, overwrite); err != nil { + return false, err } // We need to enable an item when it has been added to a collection since latest release of the collection. // We check if sub.Downloaded is false because maybe the item has been disabled by the user. if !sub.State.Installed && !downloaded { if err := sub.enable(); err != nil { - return "", fmt.Errorf("enabling '%s': %w", sub.Name, err) + return false, fmt.Errorf("enabling '%s': %w", sub.Name, err) } } } if !i.State.Installed && updateOnly && i.State.Downloaded && !overwrite { i.hub.logger.Debugf("skipping upgrade of %s: not installed", i.Name) - return "", nil - } - - ret, err := i.download(overwrite) - if err != nil { - return "", err + return false, nil } - return ret, nil + return i.download(ctx, overwrite) } -// FetchLatest downloads the latest item from the hub, verifies the hash and returns the content and the used url. -func (i *Item) FetchLatest() ([]byte, string, error) { - if i.latestHash() == "" { - return nil, "", errors.New("latest hash missing from index") +// FetchContentTo downloads the last version of the item's YAML file to the specified path. +func (i *Item) FetchContentTo(ctx context.Context, destPath string) (bool, string, error) { + wantHash := i.latestHash() + if wantHash == "" { + return false, "", errors.New("latest hash missing from index. The index file is invalid, please run 'cscli hub update' and try again") } - url, err := i.hub.remote.urlTo(i.RemotePath) - if err != nil { - return nil, "", fmt.Errorf("failed to build request: %w", err) - } + // Use the embedded content if available + if i.Content != "" { + // the content was historically base64 encoded + content, err := base64.StdEncoding.DecodeString(i.Content) + if err != nil { + content = []byte(i.Content) + } - resp, err := hubClient.Get(url) - if err != nil { - return nil, "", err - } - defer resp.Body.Close() + dir := filepath.Dir(destPath) + + if err := os.MkdirAll(dir, 0o755); err != nil { + return false, "", fmt.Errorf("while creating %s: %w", dir, err) + } + + // check sha256 + hash := crypto.SHA256.New() + if _, err := hash.Write(content); err != nil { + return false, "", fmt.Errorf("while hashing %s: %w", i.Name, err) + } + + gotHash := hex.EncodeToString(hash.Sum(nil)) + if gotHash != wantHash { + return false, "", fmt.Errorf("hash mismatch: expected %s, got %s. The index file is invalid, please run 'cscli hub update' and try again", wantHash, gotHash) + } + + if err := os.WriteFile(destPath, content, 0o600); err != nil { + return false, "", fmt.Errorf("while writing %s: %w", destPath, err) + } - if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("bad http code %d", resp.StatusCode) + i.hub.logger.Debugf("Wrote %s content from .index.json to %s", i.Name, destPath) + + return true, fmt.Sprintf("(embedded in %s)", i.hub.local.HubIndexFile), nil } - body, err := io.ReadAll(resp.Body) + url, err := i.hub.remote.urlTo(i.RemotePath) if err != nil { - return nil, "", err + return false, "", fmt.Errorf("failed to build request: %w", err) } - hash := sha256.New() - if _, err = hash.Write(body); err != nil { - return nil, "", fmt.Errorf("while hashing %s: %w", i.Name, err) - } + d := downloader. + New(). + WithHTTPClient(hubClient). + ToFile(destPath). + WithETagFn(downloader.SHA256). + WithMakeDirs(true). + WithLogger(logrus.WithField("url", url)). + CompareContent(). + VerifyHash("sha256", wantHash) - meow := hex.EncodeToString(hash.Sum(nil)) - if meow != i.Versions[i.Version].Digest { - i.hub.logger.Errorf("Downloaded version doesn't match index, please 'hub update'") - i.hub.logger.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest) + // TODO: recommend hub update if hash does not match - return nil, "", fmt.Errorf("invalid download hash") + downloaded, err := d.Download(ctx, url) + if err != nil { + return false, "", err } - return body, url, nil + return downloaded, url, nil } // download downloads the item from the hub and writes it to the hub directory. -func (i *Item) download(overwrite bool) (string, error) { +func (i *Item) download(ctx context.Context, overwrite bool) (bool, error) { // ensure that target file is within target dir finalPath, err := i.downloadPath() if err != nil { - return "", err + return false, err } if i.State.IsLocal() { i.hub.logger.Warningf("%s is local, can't download", i.Name) - return finalPath, nil + return false, nil } // if user didn't --force, don't overwrite local, tainted, up-to-date files if !overwrite { if i.State.Tainted { i.hub.logger.Debugf("%s: tainted, not updated", i.Name) - return "", nil + return false, nil } if i.State.UpToDate { @@ -183,49 +204,36 @@ func (i *Item) download(overwrite bool) (string, error) { } } - body, url, err := i.FetchLatest() + downloaded, _, err := i.FetchContentTo(ctx, finalPath) if err != nil { - what := i.Name - if url != "" { - what += " from " + url - } - - return "", fmt.Errorf("while downloading %s: %w", what, err) + return false, err } - // all good, install - - parentDir := filepath.Dir(finalPath) - - if err = os.MkdirAll(parentDir, os.ModePerm); err != nil { - return "", fmt.Errorf("while creating %s: %w", parentDir, err) - } - - // check actual file - if _, err = os.Stat(finalPath); !os.IsNotExist(err) { - i.hub.logger.Warningf("%s: overwrite", i.Name) - i.hub.logger.Debugf("target: %s", finalPath) - } else { - i.hub.logger.Infof("%s: OK", i.Name) - } - - if err = os.WriteFile(finalPath, body, 0o644); err != nil { - return "", fmt.Errorf("while writing %s: %w", finalPath, err) + if downloaded { + i.hub.logger.Infof("Downloaded %s", i.Name) } i.State.Downloaded = true i.State.Tainted = false i.State.UpToDate = true - if err = downloadDataSet(i.hub.local.InstallDataDir, overwrite, bytes.NewReader(body), i.hub.logger); err != nil { - return "", fmt.Errorf("while downloading data for %s: %w", i.FileName, err) + // read content to get the list of data files + reader, err := os.Open(finalPath) + if err != nil { + return false, fmt.Errorf("while opening %s: %w", finalPath, err) + } + + defer reader.Close() + + if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, overwrite, reader, i.hub.logger); err != nil { + return false, fmt.Errorf("while downloading data for %s: %w", i.FileName, err) } - return finalPath, nil + return true, nil } // DownloadDataIfNeeded downloads the data set for the item. -func (i *Item) DownloadDataIfNeeded(force bool) error { +func (i *Item) DownloadDataIfNeeded(ctx context.Context, force bool) error { itemFilePath, err := i.installPath() if err != nil { return err @@ -238,7 +246,7 @@ func (i *Item) DownloadDataIfNeeded(force bool) error { defer itemFile.Close() - if err = downloadDataSet(i.hub.local.InstallDataDir, force, itemFile, i.hub.logger); err != nil { + if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, force, itemFile, i.hub.logger); err != nil { return fmt.Errorf("while downloading data for %s: %w", itemFilePath, err) } diff --git a/pkg/cwhub/itemupgrade_test.go b/pkg/cwhub/itemupgrade_test.go index 1bd62ad63e8..5f9e4d1944e 100644 --- a/pkg/cwhub/itemupgrade_test.go +++ b/pkg/cwhub/itemupgrade_test.go @@ -1,6 +1,7 @@ package cwhub import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -18,7 +19,9 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) { require.False(t, item.State.Downloaded) require.False(t, item.State.Installed) - require.NoError(t, item.Install(false, false)) + ctx := context.Background() + + require.NoError(t, item.Install(ctx, false, false)) require.True(t, item.State.Downloaded) require.True(t, item.State.Installed) @@ -39,8 +42,14 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) { IndexPath: ".index.json", } - hub, err := NewHub(hub.local, remote, true, nil) - require.NoError(t, err, "failed to download index: %s", err) + hub, err := NewHub(hub.local, remote, nil) + require.NoError(t, err) + + err = hub.Update(ctx) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) hub = getHubOrFail(t, hub.local, remote) @@ -51,7 +60,7 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) { require.False(t, item.State.UpToDate) require.False(t, item.State.Tainted) - didUpdate, err := item.Upgrade(false) + didUpdate, err := item.Upgrade(ctx, false) require.NoError(t, err) require.True(t, didUpdate) assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection") @@ -71,7 +80,9 @@ func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) { require.False(t, item.State.Installed) require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) - require.NoError(t, item.Install(false, false)) + ctx := context.Background() + + require.NoError(t, item.Install(ctx, false, false)) require.True(t, item.State.Downloaded) require.True(t, item.State.Installed) @@ -100,11 +111,17 @@ func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) { require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.Installed) require.True(t, hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection").State.UpToDate) - hub, err = NewHub(hub.local, remote, true, nil) - require.NoError(t, err, "failed to download index: %s", err) + hub, err = NewHub(hub.local, remote, nil) + require.NoError(t, err) + + err = hub.Update(ctx) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") - didUpdate, err := item.Upgrade(false) + didUpdate, err := item.Upgrade(ctx, false) require.NoError(t, err) require.False(t, didUpdate) @@ -114,8 +131,11 @@ func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) { // getHubOrFail refreshes the hub state (load index, sync) and returns the singleton, or fails the test. func getHubOrFail(t *testing.T, local *csconfig.LocalHubCfg, remote *RemoteHubCfg) *Hub { - hub, err := NewHub(local, remote, false, nil) - require.NoError(t, err, "failed to load hub index") + hub, err := NewHub(local, remote, nil) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) return hub } @@ -132,7 +152,9 @@ func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *te require.False(t, item.State.Installed) require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) - require.NoError(t, item.Install(false, false)) + ctx := context.Background() + + require.NoError(t, item.Install(ctx, false, false)) require.True(t, item.State.Downloaded) require.True(t, item.State.Installed) @@ -166,14 +188,20 @@ func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *te // we just removed. Nor should it install the newly added scenario pushUpdateToCollectionInHub() - hub, err = NewHub(hub.local, remote, true, nil) - require.NoError(t, err, "failed to download index: %s", err) + hub, err = NewHub(hub.local, remote, nil) + require.NoError(t, err) + + err = hub.Update(ctx) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed) hub = getHubOrFail(t, hub.local, remote) item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") - didUpdate, err := item.Upgrade(false) + didUpdate, err := item.Upgrade(ctx, false) require.NoError(t, err) require.True(t, didUpdate) @@ -190,6 +218,6 @@ func assertCollectionDepsInstalled(t *testing.T, hub *Hub, collection string) { } func pushUpdateToCollectionInHub() { - responseByPath["/master/.index.json"] = fileToStringX("./testdata/index2.json") - responseByPath["/master/collections/crowdsecurity/test_collection.yaml"] = fileToStringX("./testdata/collection_v2.yaml") + responseByPath["/crowdsecurity/master/.index.json"] = fileToStringX("./testdata/index2.json") + responseByPath["/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml"] = fileToStringX("./testdata/collection_v2.yaml") } diff --git a/pkg/cwhub/leakybucket.go b/pkg/cwhub/leakybucket.go deleted file mode 100644 index 8143e9433ee..00000000000 --- a/pkg/cwhub/leakybucket.go +++ /dev/null @@ -1,53 +0,0 @@ -package cwhub - -// Resolve a symlink to find the hub item it points to. -// This file is used only by pkg/leakybucket - -import ( - "fmt" - "os" - "path/filepath" - "strings" -) - -// itemKey extracts the map key of an item (i.e. author/name) from its pathname. Follows a symlink if necessary. -func itemKey(itemPath string) (string, error) { - f, err := os.Lstat(itemPath) - if err != nil { - return "", fmt.Errorf("while performing lstat on %s: %w", itemPath, err) - } - - if f.Mode()&os.ModeSymlink == 0 { - // it's not a symlink, so the filename itsef should be the key - return filepath.Base(itemPath), nil - } - - // resolve the symlink to hub file - pathInHub, err := os.Readlink(itemPath) - if err != nil { - return "", fmt.Errorf("while reading symlink of %s: %w", itemPath, err) - } - - author := filepath.Base(filepath.Dir(pathInHub)) - - fname := filepath.Base(pathInHub) - fname = strings.TrimSuffix(fname, ".yaml") - fname = strings.TrimSuffix(fname, ".yml") - - return fmt.Sprintf("%s/%s", author, fname), nil -} - -// GetItemByPath retrieves an item from the hub index based on its local path. -func (h *Hub) GetItemByPath(itemType string, itemPath string) (*Item, error) { - itemKey, err := itemKey(itemPath) - if err != nil { - return nil, err - } - - item := h.GetItem(itemType, itemKey) - if item == nil { - return nil, fmt.Errorf("%s not found in %s", itemKey, itemType) - } - - return item, nil -} diff --git a/pkg/cwhub/relativepath.go b/pkg/cwhub/relativepath.go new file mode 100644 index 00000000000..bcd4c576840 --- /dev/null +++ b/pkg/cwhub/relativepath.go @@ -0,0 +1,28 @@ +package cwhub + +import ( + "path/filepath" + "strings" +) + +// relativePathComponents returns the list of path components after baseDir. +// If path is not inside baseDir, it returns an empty slice. +func relativePathComponents(path string, baseDir string) []string { + absPath, err := filepath.Abs(path) + if err != nil { + return []string{} + } + + absBaseDir, err := filepath.Abs(baseDir) + if err != nil { + return []string{} + } + + // is path inside baseDir? + relPath, err := filepath.Rel(absBaseDir, absPath) + if err != nil || strings.HasPrefix(relPath, "..") || relPath == "." { + return []string{} + } + + return strings.Split(relPath, string(filepath.Separator)) +} diff --git a/pkg/cwhub/relativepath_test.go b/pkg/cwhub/relativepath_test.go new file mode 100644 index 00000000000..11eba566064 --- /dev/null +++ b/pkg/cwhub/relativepath_test.go @@ -0,0 +1,72 @@ +package cwhub + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRelativePathComponents(t *testing.T) { + tests := []struct { + name string + path string + baseDir string + expected []string + }{ + { + name: "Path within baseDir", + path: "/home/user/project/src/file.go", + baseDir: "/home/user/project", + expected: []string{"src", "file.go"}, + }, + { + name: "Path is baseDir", + path: "/home/user/project", + baseDir: "/home/user/project", + expected: []string{}, + }, + { + name: "Path outside baseDir", + path: "/home/user/otherproject/src/file.go", + baseDir: "/home/user/project", + expected: []string{}, + }, + { + name: "Path is subdirectory of baseDir", + path: "/home/user/project/src/", + baseDir: "/home/user/project", + expected: []string{"src"}, + }, + { + name: "Relative paths", + path: "project/src/file.go", + baseDir: "project", + expected: []string{"src", "file.go"}, + }, + { + name: "BaseDir with trailing slash", + path: "/home/user/project/src/file.go", + baseDir: "/home/user/project/", + expected: []string{"src", "file.go"}, + }, + { + name: "Empty baseDir", + path: "/home/user/project/src/file.go", + baseDir: "", + expected: []string{}, + }, + { + name: "Empty path", + path: "", + baseDir: "/home/user/project", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := relativePathComponents(tt.path, tt.baseDir) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/cwhub/remote.go b/pkg/cwhub/remote.go index c1eb5a7080f..8d2dc2dbb94 100644 --- a/pkg/cwhub/remote.go +++ b/pkg/cwhub/remote.go @@ -1,16 +1,21 @@ package cwhub import ( + "context" "fmt" - "io" - "net/http" + "net/url" + + "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/go-cs-lib/downloader" ) // RemoteHubCfg is used to retrieve index and items from the remote hub. type RemoteHubCfg struct { - Branch string - URLTemplate string - IndexPath string + Branch string + URLTemplate string + IndexPath string + EmbedItemContent bool } // urlTo builds the URL to download a file from the remote hub. @@ -27,35 +32,53 @@ func (r *RemoteHubCfg) urlTo(remotePath string) (string, error) { return fmt.Sprintf(r.URLTemplate, r.Branch, remotePath), nil } +// addURLParam adds the "with_content=true" parameter to the URL if it's not already present. +func addURLParam(rawURL string, param string, value string) (string, error) { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + + query := parsedURL.Query() + + if _, exists := query[param]; !exists { + query.Add(param, value) + } + + parsedURL.RawQuery = query.Encode() + + return parsedURL.String(), nil +} + // fetchIndex downloads the index from the hub and returns the content. -func (r *RemoteHubCfg) fetchIndex() ([]byte, error) { +func (r *RemoteHubCfg) fetchIndex(ctx context.Context, destPath string) (bool, error) { if r == nil { - return nil, ErrNilRemoteHub + return false, ErrNilRemoteHub } url, err := r.urlTo(r.IndexPath) if err != nil { - return nil, fmt.Errorf("failed to build hub index request: %w", err) + return false, fmt.Errorf("failed to build hub index request: %w", err) } - resp, err := hubClient.Get(url) - if err != nil { - return nil, fmt.Errorf("failed http request for hub index: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusNotFound { - return nil, IndexNotFoundError{url, r.Branch} + if r.EmbedItemContent { + url, err = addURLParam(url, "with_content", "true") + if err != nil { + return false, fmt.Errorf("failed to add 'with_content' parameter to URL: %w", err) } - - return nil, fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) } - body, err := io.ReadAll(resp.Body) + downloaded, err := downloader. + New(). + WithHTTPClient(hubClient). + ToFile(destPath). + WithETagFn(downloader.SHA256). + CompareContent(). + WithLogger(logrus.WithField("url", url)). + Download(ctx, url) if err != nil { - return nil, fmt.Errorf("failed to read request answer for hub index: %w", err) + return false, err } - return body, nil + return downloaded, nil } diff --git a/pkg/cwhub/sync.go b/pkg/cwhub/sync.go index 8ce91dc2193..c82822e64ef 100644 --- a/pkg/cwhub/sync.go +++ b/pkg/cwhub/sync.go @@ -1,10 +1,8 @@ package cwhub import ( - "crypto/sha256" - "encoding/hex" + "errors" "fmt" - "io" "os" "path/filepath" "slices" @@ -14,53 +12,66 @@ import ( "github.com/Masterminds/semver/v3" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/downloader" ) func isYAMLFileName(path string) bool { return strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml") } -// linkTarget returns the target of a symlink, or empty string if it's dangling. -func linkTarget(path string, logger *logrus.Logger) (string, error) { - hubpath, err := os.Readlink(path) - if err != nil { - return "", fmt.Errorf("unable to read symlink: %s", path) - } +// resolveSymlink returns the ultimate target path of a symlink +// returns error if the symlink is dangling or too many symlinks are followed +func resolveSymlink(path string) (string, error) { + const maxSymlinks = 10 // Prevent infinite loops + for range maxSymlinks { + fi, err := os.Lstat(path) + if err != nil { + return "", err // dangling link + } - logger.Tracef("symlink %s -> %s", path, hubpath) + if fi.Mode()&os.ModeSymlink == 0 { + // found the target + return path, nil + } - _, err = os.Lstat(hubpath) - if os.IsNotExist(err) { - logger.Warningf("link target does not exist: %s -> %s", path, hubpath) - return "", nil + path, err = os.Readlink(path) + if err != nil { + return "", err + } + + // relative to the link's directory? + if !filepath.IsAbs(path) { + path = filepath.Join(filepath.Dir(path), path) + } } - return hubpath, nil + return "", errors.New("too many levels of symbolic links") } -func getSHA256(filepath string) (string, error) { - f, err := os.Open(filepath) +// isPathInside checks if a path is inside the given directory +// it can return false negatives if the filesystem is case insensitive +func isPathInside(path, dir string) (bool, error) { + absFilePath, err := filepath.Abs(path) if err != nil { - return "", fmt.Errorf("unable to open '%s': %w", filepath, err) + return false, err } - defer f.Close() - - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return "", fmt.Errorf("unable to calculate sha256 of '%s': %w", filepath, err) + absDir, err := filepath.Abs(dir) + if err != nil { + return false, err } - return hex.EncodeToString(h.Sum(nil)), nil + return strings.HasPrefix(absFilePath, absDir), nil } // information used to create a new Item, from a file path. type itemFileInfo struct { - inhub bool fname string stage string ftype string fauthor string + inhub bool } func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo, error) { @@ -69,59 +80,78 @@ func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo hubDir := h.local.HubDir installDir := h.local.InstallDir - subs := strings.Split(path, string(os.PathSeparator)) + subsHub := relativePathComponents(path, hubDir) + subsInstall := relativePathComponents(path, installDir) - logger.Tracef("path:%s, hubdir:%s, installdir:%s", path, hubDir, installDir) - logger.Tracef("subs:%v", subs) - // we're in hub (~/.hub/hub/) - if strings.HasPrefix(path, hubDir) { + switch { + case len(subsHub) > 0: logger.Tracef("in hub dir") - //.../hub/parsers/s00-raw/crowdsec/skip-pretag.yaml - //.../hub/scenarios/crowdsec/ssh_bf.yaml - //.../hub/profiles/crowdsec/linux.yaml - if len(subs) < 4 { - return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) + // .../hub/parsers/s00-raw/crowdsecurity/skip-pretag.yaml + // .../hub/scenarios/crowdsecurity/ssh_bf.yaml + // .../hub/profiles/crowdsecurity/linux.yaml + if len(subsHub) < 3 { + return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsHub)) + } + + ftype := subsHub[0] + if !slices.Contains(ItemTypes, ftype) { + // this doesn't really happen anymore, because we only scan the {hubtype} directories + return nil, fmt.Errorf("unknown configuration type '%s'", ftype) + } + + stage := "" + fauthor := subsHub[1] + fname := subsHub[2] + + if ftype == PARSERS || ftype == POSTOVERFLOWS { + stage = subsHub[1] + fauthor = subsHub[2] + fname = subsHub[3] } ret = &itemFileInfo{ inhub: true, - fname: subs[len(subs)-1], - fauthor: subs[len(subs)-2], - stage: subs[len(subs)-3], - ftype: subs[len(subs)-4], + ftype: ftype, + stage: stage, + fauthor: fauthor, + fname: fname, } - } else if strings.HasPrefix(path, installDir) { // we're in install /etc/crowdsec//... + + case len(subsInstall) > 0: logger.Tracef("in install dir") - if len(subs) < 3 { - return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) + + // .../config/parser/stage/file.yaml + // .../config/postoverflow/stage/file.yaml + // .../config/scenarios/scenar.yaml + // .../config/collections/linux.yaml //file is empty + + if len(subsInstall) < 2 { + return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsInstall)) } - ///.../config/parser/stage/file.yaml - ///.../config/postoverflow/stage/file.yaml - ///.../config/scenarios/scenar.yaml - ///.../config/collections/linux.yaml //file is empty + + // this can be in any number of subdirs, we join them to compose the item name + + ftype := subsInstall[0] + stage := "" + fname := strings.Join(subsInstall[1:], "/") + + if ftype == PARSERS || ftype == POSTOVERFLOWS { + stage = subsInstall[1] + fname = strings.Join(subsInstall[2:], "/") + } + ret = &itemFileInfo{ inhub: false, - fname: subs[len(subs)-1], - stage: subs[len(subs)-2], - ftype: subs[len(subs)-3], + ftype: ftype, + stage: stage, fauthor: "", + fname: fname, } - } else { + default: return nil, fmt.Errorf("file '%s' is not from hub '%s' nor from the configuration directory '%s'", path, hubDir, installDir) } - logger.Tracef("stage:%s ftype:%s", ret.stage, ret.ftype) - - if ret.ftype != PARSERS && ret.ftype != POSTOVERFLOWS { - if !slices.Contains(ItemTypes, ret.stage) { - return nil, fmt.Errorf("unknown configuration type for file '%s'", path) - } - - ret.ftype = ret.stage - ret.stage = "" - } - logger.Tracef("CORRECTED [%s] by [%s] in stage [%s] of type [%s]", ret.fname, ret.fauthor, ret.stage, ret.ftype) return ret, nil @@ -180,7 +210,7 @@ func newLocalItem(h *Hub, path string, info *itemFileInfo) (*Item, error) { err = yaml.Unmarshal(itemContent, &itemName) if err != nil { - return nil, fmt.Errorf("failed to unmarshal %s: %w", path, err) + return nil, fmt.Errorf("failed to parse %s: %w", path, err) } if itemName.Name != "" { @@ -191,8 +221,6 @@ func newLocalItem(h *Hub, path string, info *itemFileInfo) (*Item, error) { } func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { - hubpath := "" - if err != nil { h.logger.Debugf("while syncing hub dir: %s", err) // there is a path error, we ignore the file @@ -205,45 +233,67 @@ func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { return err } + // permission errors, files removed while reading, etc. + if f == nil { + return nil + } + + if f.IsDir() { + // if a directory starts with a dot, we don't traverse it + // - single dot prefix is hidden by unix convention + // - double dot prefix is used by k8s to mount config maps + if strings.HasPrefix(f.Name(), ".") { + h.logger.Tracef("skipping hidden directory %s", path) + return filepath.SkipDir + } + + // keep traversing + return nil + } + // we only care about YAML files - if f == nil || f.IsDir() || !isYAMLFileName(f.Name()) { + if !isYAMLFileName(f.Name()) { return nil } info, err := h.getItemFileInfo(path, h.logger) if err != nil { - return err + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil } - // non symlinks are local user files or hub files - if f.Type()&os.ModeSymlink == 0 { - h.logger.Tracef("%s is not a symlink", path) - - if !info.inhub { - h.logger.Tracef("%s is a local file, skip", path) + // follow the link to see if it falls in the hub directory + // if it's not a link, target == path + target, err := resolveSymlink(path) + if err != nil { + // target does not exist, the user might have removed the file + // or switched to a hub branch without it; or symlink loop + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } - item, err := newLocalItem(h, path, info) - if err != nil { - return err - } + targetInHub, err := isPathInside(target, h.local.HubDir) + if err != nil { + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } - h.addItem(item) + // local (custom) item if the file or link target is not inside the hub dir + if !targetInHub { + h.logger.Tracef("%s is a local file, skip", path) - return nil - } - } else { - hubpath, err = linkTarget(path, h.logger) + item, err := newLocalItem(h, path, info) if err != nil { return err } - if hubpath == "" { - // target does not exist, the user might have removed the file - // or switched to a hub branch without it - return nil - } + h.addItem(item) + + return nil } + hubpath := target + // try to find which configuration item it is h.logger.Tracef("check [%s] of %s", info.fname, info.ftype) @@ -288,6 +338,8 @@ func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { return err } + h.pathIndex[path] = item + return nil } @@ -465,7 +517,7 @@ func (h *Hub) localSync() error { func (i *Item) setVersionState(path string, inhub bool) error { var err error - i.State.LocalHash, err = getSHA256(path) + i.State.LocalHash, err = downloader.SHA256(path) if err != nil { return fmt.Errorf("failed to get sha256 of %s: %w", path, err) } diff --git a/pkg/cwhub/testdata/index1.json b/pkg/cwhub/testdata/index1.json index a7e6ef6153b..59548bda379 100644 --- a/pkg/cwhub/testdata/index1.json +++ b/pkg/cwhub/testdata/index1.json @@ -10,7 +10,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "foobar collection : foobar", "author": "crowdsecurity", "labels": null, @@ -34,7 +33,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "test_collection : foobar", "author": "crowdsecurity", "labels": null, @@ -52,7 +50,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "foobar collection : foobar", "author": "crowdsecurity", "labels": null, @@ -73,7 +70,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -89,7 +85,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -107,7 +102,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "a foobar scenario", "author": "crowdsecurity", "labels": { @@ -118,4 +112,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/cwhub/testdata/index2.json b/pkg/cwhub/testdata/index2.json index 7f97ebf2308..41c4ccba83a 100644 --- a/pkg/cwhub/testdata/index2.json +++ b/pkg/cwhub/testdata/index2.json @@ -10,7 +10,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "foobar collection : foobar", "author": "crowdsecurity", "labels": null, @@ -38,7 +37,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "test_collection : foobar", "author": "crowdsecurity", "labels": null, @@ -57,7 +55,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "foobar collection : foobar", "author": "crowdsecurity", "labels": null, @@ -78,7 +75,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -94,7 +90,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "A foobar parser", "author": "crowdsecurity", "labels": null @@ -112,7 +107,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "a foobar scenario", "author": "crowdsecurity", "labels": { @@ -132,7 +126,6 @@ } }, "long_description": "bG9uZyBkZXNjcmlwdGlvbgo=", - "content": "bG9uZyBkZXNjcmlwdGlvbgo=", "description": "a foobar scenario", "author": "crowdsecurity", "labels": { @@ -143,4 +136,4 @@ } } } -} \ No newline at end of file +} diff --git a/pkg/cwversion/component/component.go b/pkg/cwversion/component/component.go new file mode 100644 index 00000000000..4036b63cf00 --- /dev/null +++ b/pkg/cwversion/component/component.go @@ -0,0 +1,34 @@ +package component + +// Package component provides functionality for managing the registration of +// optional, compile-time components in the system. This is meant as a space +// saving measure, separate from feature flags (package pkg/fflag) which are +// only enabled/disabled at runtime. + +// Built is a map of all the known components, and whether they are built-in or not. +// This is populated as soon as possible by the respective init() functions +var Built = map[string]bool { + "datasource_appsec": false, + "datasource_cloudwatch": false, + "datasource_docker": false, + "datasource_file": false, + "datasource_journalctl": false, + "datasource_k8s-audit": false, + "datasource_kafka": false, + "datasource_kinesis": false, + "datasource_loki": false, + "datasource_s3": false, + "datasource_syslog": false, + "datasource_wineventlog":false, + "cscli_setup": false, +} + +func Register(name string) { + if _, ok := Built[name]; !ok { + // having a list of the disabled components is essential + // to debug users' issues + panic("cannot register unknown compile-time component: " + name) + } + + Built[name] = true +} diff --git a/pkg/cwversion/constraint/constraint.go b/pkg/cwversion/constraint/constraint.go new file mode 100644 index 00000000000..67593f9ebbc --- /dev/null +++ b/pkg/cwversion/constraint/constraint.go @@ -0,0 +1,32 @@ +package constraint + +import ( + "fmt" + + goversion "github.com/hashicorp/go-version" +) + +const ( + Parser = ">= 1.0, <= 3.0" + Scenario = ">= 1.0, <= 3.0" + API = "v1" + Acquis = ">= 1.0, < 2.0" +) + +func Satisfies(strvers string, constraint string) (bool, error) { + vers, err := goversion.NewVersion(strvers) + if err != nil { + return false, fmt.Errorf("failed to parse '%s': %w", strvers, err) + } + + constraints, err := goversion.NewConstraint(constraint) + if err != nil { + return false, fmt.Errorf("failed to parse constraint '%s'", constraint) + } + + if !constraints.Check(vers) { + return false, nil + } + + return true, nil +} diff --git a/pkg/cwversion/version.go b/pkg/cwversion/version.go index 6f85704d8e5..2cb7de13e18 100644 --- a/pkg/cwversion/version.go +++ b/pkg/cwversion/version.go @@ -1,111 +1,66 @@ package cwversion import ( - "encoding/json" "fmt" - "log" - "net/http" - "runtime" "strings" - goversion "github.com/hashicorp/go-version" - + "github.com/crowdsecurity/go-cs-lib/maptools" "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" ) var ( - Codename string // = "SoumSoum" - System = runtime.GOOS // = "linux" + Codename string // = "SoumSoum" Libre2 = "WebAssembly" ) -const ( - Constraint_parser = ">= 1.0, <= 3.0" - Constraint_scenario = ">= 1.0, <= 3.0" - Constraint_api = "v1" - Constraint_acquis = ">= 1.0, < 2.0" -) +func FullString() string { + dsBuilt := map[string]struct{}{} + dsExcluded := map[string]struct{}{} -func versionWithTag() string { - ret := version.Version + for ds, built := range component.Built { + if built { + dsBuilt[ds] = struct{}{} + continue + } - if !strings.HasSuffix(ret, version.Tag) { - ret += fmt.Sprintf("-%s", version.Tag) + dsExcluded[ds] = struct{}{} } - return ret -} - -func ShowStr() string { - ret := fmt.Sprintf("version: %s", versionWithTag()) + ret := fmt.Sprintf("version: %s\n", version.String()) ret += fmt.Sprintf("Codename: %s\n", Codename) ret += fmt.Sprintf("BuildDate: %s\n", version.BuildDate) ret += fmt.Sprintf("GoVersion: %s\n", version.GoVersion) - ret += fmt.Sprintf("Platform: %s\n", System) - - return ret -} - -func Show() { - log.Printf("version: %s", versionWithTag()) - log.Printf("Codename: %s", Codename) - log.Printf("BuildDate: %s", version.BuildDate) - log.Printf("GoVersion: %s", version.GoVersion) - log.Printf("Platform: %s\n", System) - log.Printf("libre2: %s\n", Libre2) - log.Printf("Constraint_parser: %s", Constraint_parser) - log.Printf("Constraint_scenario: %s", Constraint_scenario) - log.Printf("Constraint_api: %s", Constraint_api) - log.Printf("Constraint_acquis: %s", Constraint_acquis) -} - -func VersionStr() string { - return fmt.Sprintf("%s-%s-%s", version.Version, System, version.Tag) -} - -func VersionStrip() string { - version := strings.Split(version.Version, "~") - version = strings.Split(version[0], "-") - - return version[0] -} - -func Satisfies(strvers string, constraint string) (bool, error) { - vers, err := goversion.NewVersion(strvers) - if err != nil { - return false, fmt.Errorf("failed to parse '%s' : %v", strvers, err) + ret += fmt.Sprintf("Platform: %s\n", version.System) + ret += fmt.Sprintf("libre2: %s\n", Libre2) + ret += fmt.Sprintf("User-Agent: %s\n", useragent.Default()) + ret += fmt.Sprintf("Constraint_parser: %s\n", constraint.Parser) + ret += fmt.Sprintf("Constraint_scenario: %s\n", constraint.Scenario) + ret += fmt.Sprintf("Constraint_api: %s\n", constraint.API) + ret += fmt.Sprintf("Constraint_acquis: %s\n", constraint.Acquis) + + built := "(none)" + + if len(dsBuilt) > 0 { + built = strings.Join(maptools.SortedKeys(dsBuilt), ", ") } - constraints, err := goversion.NewConstraint(constraint) - if err != nil { - return false, fmt.Errorf("failed to parse constraint '%s'", constraint) - } + ret += fmt.Sprintf("Built-in optional components: %s\n", built) - if !constraints.Check(vers) { - return false, nil + if len(dsExcluded) > 0 { + ret += fmt.Sprintf("Excluded components: %s\n", strings.Join(maptools.SortedKeys(dsExcluded), ", ")) } - return true, nil + return ret } -// Latest return latest crowdsec version based on github -func Latest() (string, error) { - latest := make(map[string]interface{}) - - resp, err := http.Get("https://version.crowdsec.net/latest") - if err != nil { - return "", err - } - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&latest) - if err != nil { - return "", err - } - - if _, ok := latest["name"]; !ok { - return "", fmt.Errorf("unable to find latest release name from github api: %+v", latest) - } +// VersionStrip remove the tag from the version string, used to match with a hub branch +func VersionStrip() string { + ret := strings.Split(version.Version, "~") + ret = strings.Split(ret[0], "-") - return latest["name"].(string), nil + return ret[0] } diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 0502c25312d..ede9c89fe9a 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -10,8 +10,6 @@ import ( "time" "github.com/mattn/go-sqlite3" - - "github.com/davecgh/go-spew/spew" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -28,124 +26,29 @@ import ( ) const ( - paginationSize = 100 // used to queryAlert to avoid 'too many SQL variable' - defaultLimit = 100 // default limit of element to returns when query alerts - bulkSize = 50 // bulk size when create alerts - maxLockRetries = 10 // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered + paginationSize = 100 // used to queryAlert to avoid 'too many SQL variable' + defaultLimit = 100 // default limit of element to returns when query alerts + alertCreateBulkSize = 50 // bulk size when create alerts + maxLockRetries = 10 // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered ) -func formatAlertCN(source models.Source) string { - cn := source.Cn - - if source.AsNumber != "" { - cn += "/" + source.AsNumber - } - - return cn -} - -func formatAlertSource(alert *models.Alert) string { - if alert.Source == nil || alert.Source.Scope == nil || *alert.Source.Scope == "" { - return "empty source" - } - - if *alert.Source.Scope == types.Ip { - ret := "ip " + *alert.Source.Value - - cn := formatAlertCN(*alert.Source) - if cn != "" { - ret += " (" + cn + ")" - } - - return ret - } - - if *alert.Source.Scope == types.Range { - ret := "range " + *alert.Source.Value - - cn := formatAlertCN(*alert.Source) - if cn != "" { - ret += " (" + cn + ")" - } - - return ret - } - - return *alert.Source.Scope + " " + *alert.Source.Value -} - -func formatAlertAsString(machineID string, alert *models.Alert) []string { - src := formatAlertSource(alert) - - msg := "empty scenario" - if alert.Scenario != nil && *alert.Scenario != "" { - msg = *alert.Scenario - } else if alert.Message != nil && *alert.Message != "" { - msg = *alert.Message - } - - reason := fmt.Sprintf("%s by %s", msg, src) - - if len(alert.Decisions) == 0 { - return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} - } - - var retStr []string - - if alert.Decisions[0].Origin != nil && *alert.Decisions[0].Origin == types.CscliImportOrigin { - return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} - } - - for i, decisionItem := range alert.Decisions { - decision := "" - if alert.Simulated != nil && *alert.Simulated { - decision = "(simulated alert)" - } else if decisionItem.Simulated != nil && *decisionItem.Simulated { - decision = "(simulated decision)" - } - - if log.GetLevel() >= log.DebugLevel { - /*spew is expensive*/ - log.Debugf("%s", spew.Sdump(decisionItem)) - } - - if len(alert.Decisions) > 1 { - reason = fmt.Sprintf("%s for %d/%d decisions", msg, i+1, len(alert.Decisions)) - } - - var machineIDOrigin string - if machineID == "" { - machineIDOrigin = *decisionItem.Origin - } else { - machineIDOrigin = fmt.Sprintf("%s/%s", machineID, *decisionItem.Origin) - } - - decision += fmt.Sprintf("%s %s on %s %s", *decisionItem.Duration, - *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) - retStr = append(retStr, - fmt.Sprintf("(%s) %s : %s", machineIDOrigin, reason, decision)) - } - - return retStr -} - // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { - return "", fmt.Errorf("alert UUID is empty") + return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) } - //alert wasn't found, insert it (expected hotpath) + // alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } @@ -153,14 +56,14 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) return alertIDs[0], nil } - //this should never happen + // this should never happen if len(alerts) > 1 { return "", fmt.Errorf("multiple alerts found for uuid %s", alertItem.UUID) } log.Infof("Alert %s already exists, checking associated decisions", alertItem.UUID) - //alert is found, check for any missing decisions + // alert is found, check for any missing decisions newUuids := make([]string, len(alertItem.Decisions)) for i, decItem := range alertItem.Decisions { @@ -203,14 +106,16 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) } } - //add missing decisions + // add missing decisions log.Debugf("Adding %d missing decisions to alert %s", len(missingDecisions), foundAlert.UUID) decisionBuilders := []*ent.DecisionCreate{} for _, decisionItem := range missingDecisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) /*if the scope is IP or Range, convert the value to integers */ if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { @@ -227,7 +132,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) continue } - //use the created_at from the alert instead + // use the created_at from the alert instead alertTime, err := time.Parse(time.RFC3339, alertItem.CreatedAt) if err != nil { log.Errorf("unable to parse alert time %s : %s", alertItem.CreatedAt, err) @@ -260,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { return "", fmt.Errorf("creating alert decisions: %w", err) } @@ -268,12 +173,12 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) decisions = append(decisions, decisionsCreateRet...) } - //now that we bulk created missing decisions, let's update the alert + // now that we bulk created missing decisions, let's update the alert decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } @@ -286,13 +191,13 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { - return 0, 0, 0, fmt.Errorf("nil alert") + return 0, 0, 0, errors.New("nil alert") } if alertItem.StartAt == nil { - return 0, 0, 0, fmt.Errorf("nil start_at") + return 0, 0, 0, errors.New("nil start_at") } startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) @@ -301,7 +206,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in } if alertItem.StopAt == nil { - return 0, 0, 0, fmt.Errorf("nil stop_at") + return 0, 0, 0, errors.New("nil stop_at") } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) @@ -336,9 +241,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetLeakSpeed(*alertItem.Leakspeed). SetSimulated(*alertItem.Simulated). SetScenarioVersion(*alertItem.ScenarioVersion). - SetScenarioHash(*alertItem.ScenarioHash) + SetScenarioHash(*alertItem.ScenarioHash). + SetRemediation(true) // it's from CAPI, we always have decisions - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -347,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } @@ -367,8 +273,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in valueList := make([]string, 0, len(alertItem.Decisions)) for _, decisionItem := range alertItem.Decisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) if decisionItem.Duration == nil { log.Warning("nil duration in community decision") @@ -439,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.ValueIn(deleteChunk...), - )).Exec(c.CTX) + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -455,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -483,12 +391,14 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, inserted, deleted, nil } -func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { decisionCreate := []*ent.DecisionCreate{} for _, decisionItem := range decisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + var ( + start_ip, start_sfx, end_ip, end_sfx int64 + sz int + ) duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { @@ -526,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return nil, nil } - ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) if err != nil { return nil, err } @@ -534,33 +444,36 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return ret, nil } -func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { alertBuilders := []*ent.AlertCreate{} alertDecisions := [][]*ent.Decision{} for _, alertItem := range alerts { - var metas []*ent.Meta - var events []*ent.Event + var ( + metas []*ent.Meta + events []*ent.Event + ) startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + c.Log.Errorf("creating alert: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) startAtTime = time.Now().UTC() } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + c.Log.Errorf("creating alert: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) stopAtTime = time.Now().UTC() } + /*display proper alert in logs*/ - for _, disp := range formatAlertAsString(machineID, alertItem) { + for _, disp := range alertItem.FormatAsStrings(machineID, log.StandardLogger()) { c.Log.Info(disp) } - //let's track when we strip or drop data, notify outside of loop to avoid spam + // let's track when we strip or drop data, notify outside of loop to avoid spam stripped := false dropped := false @@ -570,7 +483,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for i, eventItem := range alertItem.Events { ts, err := time.Parse(time.RFC3339, *eventItem.Timestamp) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + c.Log.Errorf("creating alert: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) ts = time.Now().UTC() } @@ -580,7 +493,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, errors.Wrapf(MarshalFail, "event meta '%v' : %s", eventItem.Meta, err) } - //the serialized field is too big, let's try to progressively strip it + // the serialized field is too big, let's try to progressively strip it if event.SerializedValidator(string(marshallMetas)) != nil { stripped = true @@ -606,7 +519,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ stripSize /= 2 } - //nothing worked, drop it + // nothing worked, drop it if !valid { dropped = true stripped = false @@ -627,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -635,15 +548,31 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ if len(alertItem.Meta) > 0 { metaBulk := make([]*ent.MetaCreate, len(alertItem.Meta)) + for i, metaItem := range alertItem.Meta { + key := metaItem.Key + value := metaItem.Value + + if len(metaItem.Value) > 4095 { + c.Log.Warningf("truncated meta %s: value too long", metaItem.Key) + + value = value[:4095] + } + + if len(metaItem.Key) > 255 { + c.Log.Warningf("truncated meta %s: key too long", metaItem.Key) + + key = key[:255] + } + metaBulk[i] = c.Ent.Meta.Create(). - SetKey(metaItem.Key). - SetValue(metaItem.Value) + SetKey(key). + SetValue(value) } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert meta: %s", err) + c.Log.Warningf("error creating alert meta: %s", err) } } @@ -651,7 +580,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { return nil, fmt.Errorf("creating alert decisions: %w", err) } @@ -691,6 +620,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ SetSimulated(*alertItem.Simulated). SetScenarioVersion(*alertItem.ScenarioVersion). SetScenarioHash(*alertItem.ScenarioHash). + SetRemediation(alertItem.Remediation). SetUUID(alertItem.UUID). AddEvents(events...). AddMetas(metas...) @@ -708,7 +638,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, nil } - alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } @@ -725,7 +655,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for retry < maxLockRetries { // so much for the happy path... but sqlite3 errors work differently - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) if err == nil { break } @@ -754,18 +684,20 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return ret, nil } -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { - var owner *ent.Machine - var err error +func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) { + var ( + owner *ent.Machine + err error + ) if machineID != "" { - owner, err = c.QueryMachineByID(machineID) + owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } - c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineID) + c.Log.Debugf("creating alert: machine %s doesn't exist", machineID) owner = nil } @@ -773,11 +705,11 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str c.Log.Debugf("writing %d items", len(alertList)) - alertChunks := slicetools.Chunks(alertList, bulkSize) + alertChunks := slicetools.Chunks(alertList, alertCreateBulkSize) alertIDs := []string{} for _, alertChunk := range alertChunks { - ids, err := c.createAlertChunk(machineID, owner, alertChunk) + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -785,31 +717,187 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str alertIDs = append(alertIDs, ids...) } + if owner != nil { + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) + if err != nil { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } + } + return alertIDs, nil } +func handleSimulatedFilter(filter map[string][]string, predicates *[]predicate.Alert) { + /* the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ + if v, ok := filter["simulated"]; ok && v[0] == "false" { + *predicates = append(*predicates, alert.SimulatedEQ(false)) + } +} + +func handleOriginFilter(filter map[string][]string, predicates *[]predicate.Alert) { + if _, ok := filter["origin"]; ok { + filter["include_capi"] = []string{"true"} + } +} + +func handleScopeFilter(scope string, predicates *[]predicate.Alert) { + if strings.ToLower(scope) == "ip" { + scope = types.Ip + } else if strings.ToLower(scope) == "range" { + scope = types.Range + } + + *predicates = append(*predicates, alert.SourceScopeEQ(scope)) +} + +func handleTimeFilters(param, value string, predicates *[]predicate.Alert) error { + duration, err := ParseDuration(value) + if err != nil { + return fmt.Errorf("while parsing duration: %w", err) + } + + timePoint := time.Now().UTC().Add(-duration) + if timePoint.IsZero() { + return fmt.Errorf("empty time now() - %s", timePoint.String()) + } + + switch param { + case "since": + *predicates = append(*predicates, alert.StartedAtGTE(timePoint)) + case "created_before": + *predicates = append(*predicates, alert.CreatedAtLTE(timePoint)) + case "until": + *predicates = append(*predicates, alert.StartedAtLTE(timePoint)) + } + + return nil +} + +func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } +} + +func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip < query.start_ip + alert.HasDecisionsWith(decision.StartIPLT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix <= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip > query.end_ip + alert.HasDecisionsWith(decision.EndIPGT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix >= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), + ), + ), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip > query.start_ip + alert.HasDecisionsWith(decision.StartIPGT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix >= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip < query.end_ip + alert.HasDecisionsWith(decision.EndIPLT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix <= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), + ), + ), + )) + } +} + +func handleIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error { + if ip_sz == 4 { + handleIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz == 16 { + handleIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz != 0 { + return errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + } + + return nil +} + +func handleIncludeCapiFilter(value string, predicates *[]predicate.Alert) error { + if value == "false" { + *predicates = append(*predicates, alert.And( + // do not show alerts with active decisions having origin CAPI or lists + alert.And( + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), + ), + alert.Not( + alert.And( + // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI + alert.Not(alert.HasDecisions()), + alert.Or( + alert.SourceScopeHasPrefix(types.ListOrigin+":"), + alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), + ), + ), + ), + )) + } else if value != "true" { + log.Errorf("invalid bool '%s' for include_capi", value) + } + + return nil +} + func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { predicates := make([]predicate.Alert, 0) - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var hasActiveDecision bool - var ip_sz int - var contains = true + var ( + err error + start_ip, start_sfx, end_ip, end_sfx int64 + hasActiveDecision bool + ip_sz int + ) + + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ - /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ - if v, ok := filter["simulated"]; ok { - if v[0] == "false" { - predicates = append(predicates, alert.SimulatedEQ(false)) - } - } - - if _, ok := filter["origin"]; ok { - filter["include_capi"] = []string{"true"} - } + handleSimulatedFilter(filter, &predicates) + handleOriginFilter(filter, &predicates) for param, value := range filter { switch param { @@ -819,14 +907,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - var scope = value[0] - if strings.ToLower(scope) == "ip" { - scope = types.Ip - } else if strings.ToLower(scope) == "range" { - scope = types.Range - } - - predicates = append(predicates, alert.SourceScopeEQ(scope)) + handleScopeFilter(value[0], &predicates) case "value": predicates = append(predicates, alert.SourceValueEQ(value[0])) case "scenario": @@ -836,69 +917,17 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err != nil { return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) } - case "since": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) - } - - predicates = append(predicates, alert.StartedAtGTE(since)) - case "created_before": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) + case "since", "created_before", "until": + if err := handleTimeFilters(param, value[0], &predicates); err != nil { + return nil, err } - - predicates = append(predicates, alert.CreatedAtLTE(since)) - case "until": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - until := time.Now().UTC().Add(-duration) - if until.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", until.String()) - } - - predicates = append(predicates, alert.StartedAtLTE(until)) case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) case "origin": predicates = append(predicates, alert.HasDecisionsWith(decision.OriginEQ(value[0]))) - case "include_capi": //allows to exclude one or more specific origins - if value[0] == "false" { - predicates = append(predicates, alert.And( - //do not show alerts with active decisions having origin CAPI or lists - alert.And( - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), - ), - alert.Not( - alert.And( - //do not show neither alerts with no decisions if the Source Scope is lists: or CAPI - alert.Not(alert.HasDecisions()), - alert.Or( - alert.SourceScopeHasPrefix(types.ListOrigin+":"), - alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), - ), - ), - ), - ), - ) - - } else if value[0] != "true" { - log.Errorf("Invalid bool '%s' for include_capi", value[0]) + case "include_capi": // allows to exclude one or more specific origins + if err = handleIncludeCapiFilter(value[0], &predicates); err != nil { + return nil, err } case "has_active_decision": if hasActiveDecision, err = strconv.ParseBool(value[0]); err != nil { @@ -923,72 +952,8 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e } } - if ip_sz == 4 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } - } else if ip_sz == 16 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - //matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - //decision.start_ip < query.start_ip - alert.HasDecisionsWith(decision.StartIPLT(start_ip)), - alert.And( - //decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - //decision.start_suffix <= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), - )), - alert.Or( - //decision.end_ip > query.end_ip - alert.HasDecisionsWith(decision.EndIPGT(end_ip)), - alert.And( - //decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - //decision.end_suffix >= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), - ), - ), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - //matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - //decision.start_ip > query.start_ip - alert.HasDecisionsWith(decision.StartIPGT(start_ip)), - alert.And( - //decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - //decision.start_suffix >= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), - )), - alert.Or( - //decision.end_ip < query.end_ip - alert.HasDecisionsWith(decision.EndIPLT(end_ip)), - alert.And( - //decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - //decision.end_suffix <= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), - ), - ), - )) - } - } else if ip_sz != 0 { - return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { + return nil, err } return predicates, nil @@ -1003,24 +968,20 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str return alerts.Where(preds...), nil } -func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) { +func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) { var res []struct { Scenario string Count int } - ctx := context.Background() - query := c.Ent.Alert.Query() query, err := BuildAlertRequestFromFilter(query, filters) - if err != nil { return nil, fmt.Errorf("failed to build alert request: %w", err) } err = query.GroupBy(alert.FieldScenario).Aggregate(ent.Count()).Scan(ctx, &res) - if err != nil { return nil, fmt.Errorf("failed to count alerts per scenario: %w", err) } @@ -1034,11 +995,11 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default if val, ok := filter["sort"]; ok { @@ -1071,7 +1032,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return nil, err } - //only if with_decisions is present and set to false, we exclude this + // only if with_decisions is present and set to false, we exclude this if val, ok := filter["with_decisions"]; ok && val[0] == "false" { c.Log.Debugf("skipping decisions") } else { @@ -1085,9 +1046,9 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { - return nil, fmt.Errorf("unable to count nb alerts: %s", err) + return nil, fmt.Errorf("unable to count nb alerts: %w", err) } } @@ -1097,7 +1058,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } @@ -1126,35 +1087,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") @@ -1165,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1176,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1184,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1200,26 +1161,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 496b9b6cc9c..04ef830ae72 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -1,17 +1,48 @@ package database import ( + "context" "fmt" + "strings" "time" "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" + "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +type BouncerNotFoundError struct { + BouncerName string +} + +func (e *BouncerNotFoundError) Error() string { + return fmt.Sprintf("'%s' does not exist", e.BouncerName) +} + +func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { + os := baseMetrics.Os + features := strings.Join(baseMetrics.FeatureFlags, ",") + + _, err := c.Ent.Bouncer. + Update(). + Where(bouncer.NameEQ(bouncerName)). + SetNillableVersion(baseMetrics.Version). + SetOsname(*os.Name). + SetOsversion(*os.Version). + SetFeatureflags(features). + SetType(bouncerType). + Save(ctx) + if err != nil { + return fmt.Errorf("unable to update base bouncer metrics in database: %w", err) + } + + return nil +} + +func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) if err != nil { return nil, err } @@ -19,8 +50,8 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) if err != nil { return nil, err } @@ -28,85 +59,102 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } + return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) } - return nil, fmt.Errorf("unable to create bouncer: %s", err) + + return nil, fmt.Errorf("unable to create bouncer: %w", err) } + return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } if nbDeleted == 0 { - return fmt.Errorf("bouncer doesn't exist") + return &BouncerNotFoundError{BouncerName: name} } return nil } -func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { ids := make([]int, len(bouncers)) for i, b := range bouncers { ids[i] = b.ID } - nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX) + + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) if err != nil { - return nbDeleted, fmt.Errorf("unable to delete bouncers: %s", err) + return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) } + return nbDeleted, nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID). +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine last pull in database: %s", err) + return fmt.Errorf("unable to update machine last pull in database: %w", err) } + return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { - return fmt.Errorf("unable to update bouncer ip address in database: %s", err) + return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } + return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, ID int) error { - _, err := c.Ent.Bouncer.UpdateOneID(ID).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { - return fmt.Errorf("unable to update bouncer type and version in database: %s", err) + return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } + return nil } -func (c *Client) QueryBouncersLastPulltimeLT(t time.Time) ([]*ent.Bouncer, error) { - return c.Ent.Bouncer.Query().Where(bouncer.LastPullLT(t)).All(c.CTX) +func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { + return c.Ent.Bouncer.Query().Where( + // poor man's coalesce + bouncer.Or( + bouncer.LastPullLT(t), + bouncer.And( + bouncer.LastPullIsNil(), + bouncer.CreatedAtLT(t), + ), + ), + ).All(ctx) } diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..89ccb1e1b28 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,17 +1,20 @@ package database import ( + "context" + "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } + if err != nil { return nil, errors.Wrapf(QueryFail, "select config item: %s", err) } @@ -19,16 +22,16 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) - if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) + if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } } else if err != nil { return errors.Wrapf(QueryFail, "update config item: %s", err) } + return nil } diff --git a/pkg/database/database.go b/pkg/database/database.go index aa191d7dc43..bb41dd3b645 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -3,17 +3,17 @@ package database import ( "context" "database/sql" + "errors" "fmt" "os" entsql "entgo.io/ent/dialect/sql" + // load database backends _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/mattn/go-sqlite3" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -21,7 +21,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -34,72 +33,82 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig. if err != nil { return nil, err } - if config.MaxOpenConns == nil { - log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS) - config.MaxOpenConns = ptr.Of(csconfig.DEFAULT_MAX_OPEN_CONNS) + + if config.MaxOpenConns == 0 { + config.MaxOpenConns = csconfig.DEFAULT_MAX_OPEN_CONNS } - db.SetMaxOpenConns(*config.MaxOpenConns) + + db.SetMaxOpenConns(config.MaxOpenConns) drv := entsql.OpenDB(dbdialect, db) + return drv, nil } -func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { +func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, error) { var client *ent.Client - var err error + if config == nil { - return &Client{}, fmt.Errorf("DB config is empty") + return nil, errors.New("DB config is empty") } /*The logger that will be used by db operations*/ clog := log.New() if err := types.ConfigureLogger(clog); err != nil { return nil, fmt.Errorf("while configuring db logger: %w", err) } + if config.LogLevel != nil { clog.SetLevel(*config.LogLevel) } - entLogger := clog.WithField("context", "ent") + entLogger := clog.WithField("context", "ent") entOpt := ent.Log(entLogger.Debug) + typ, dia, err := config.ConnectionDialect() if err != nil { - return &Client{}, err //unsupported database caught here + return nil, err // unsupported database caught here } + if config.Type == "sqlite" { /*if it's the first startup, we want to touch and chmod file*/ - if _, err := os.Stat(config.DbPath); os.IsNotExist(err) { - f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600) + if _, err = os.Stat(config.DbPath); os.IsNotExist(err) { + f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0o600) if err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } + if err := f.Close(); err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } } - //Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) - if err := setFilePerm(config.DbPath, 0640); err != nil { - return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) + // Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) + if err = setFilePerm(config.DbPath, 0o640); err != nil { + return nil, fmt.Errorf("unable to set perms on %s: %w", config.DbPath, err) } } + drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err) + return nil, fmt.Errorf("failed opening connection to %s: %w", config.Type, err) } + client = ent.NewClient(ent.Driver(drv), entOpt) + if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel { clog.Debugf("Enabling request debug") + client = client.Debug() } - if err = client.Schema.Create(context.Background()); err != nil { - return nil, fmt.Errorf("failed creating schema resources: %v", err) + + if err = client.Schema.Create(ctx); err != nil { + return nil, fmt.Errorf("failed creating schema resources: %w", err) } return &Client{ - Ent: client, - CTX: context.Background(), - Log: clog, - CanFlush: true, - Type: config.Type, - WalMode: config.UseWal, + Ent: client, + Log: clog, + CanFlush: true, + Type: config.Type, + WalMode: config.UseWal, decisionBulkSize: config.DecisionBulkSize, }, nil } diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index c4ea0bb119e..7522a272799 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strconv" "strings" @@ -17,6 +18,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) +const decisionDeleteBulkSize = 256 // scientifically proven to be the best value for bulk delete + type DecisionsByScenario struct { Scenario string Count int @@ -28,7 +31,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ @@ -37,6 +40,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] if v[0] == "false" { query = query.Where(decision.SimulatedEQ(false)) } + delete(filter, "simulated") } else { query = query.Where(decision.SimulatedEQ(false)) @@ -49,7 +53,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] if err != nil { return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } - case "scopes": + case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything scopes := strings.Split(value[0], ",") for i, scope := range scopes { switch strings.ToLower(scope) { @@ -63,6 +67,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] scopes[i] = types.AS } } + query = query.Where(decision.ScopeIn(scopes...)) case "value": query = query.Where(decision.ValueEQ(value[0])) @@ -107,23 +112,25 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] query = query.Where(decision.IDGT(id)) } } + query, err = applyStartIpEndIpFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) } + return query, nil } -func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { + +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } query, err := BuildDecisionRequestWithFilter(query, filters) - if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -131,19 +138,20 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") } + return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } @@ -156,20 +164,22 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") } - data, err := query.All(c.CTX) + + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") } + return data, nil } -func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*DecisionsByScenario, error) { +func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := BuildDecisionRequestWithFilter(query, make(map[string][]string)) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -177,8 +187,7 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D var r []*DecisionsByScenario - err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r) - + err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -187,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D return r, nil } -func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { var data []*ent.Decision var err error @@ -209,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec decision.FieldValue, decision.FieldScope, decision.FieldOrigin, - ).Scan(c.CTX, &data) + ).Scan(ctx, &data) if err != nil { c.Log.Warningf("QueryDecisionWithFilter : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed") @@ -246,15 +255,20 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), - decision.UntilGT(since), ) - //Allow a bouncer to ask for non-deduplicated results + + if since != nil { + query = query.Where(decision.UntilGT(*since)) + } + + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) @@ -263,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -272,15 +286,20 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( - decision.CreatedAtGT(since), decision.UntilGT(time.Now().UTC()), ) - //Allow a bouncer to ask for non-deduplicated results + + if since != nil { + query = query.Where(decision.CreatedAtGT(*since)) + } + + // Allow a bouncer to ask for non-deduplicated results if v, ok := filters["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) @@ -289,34 +308,25 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[ query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) } - return data, nil -} -func (c *Client) DeleteDecisionById(decisionId int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionId)).All(c.CTX) - if err != nil { - c.Log.Warningf("DeleteDecisionById : %s", err) - return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionId) - } - count, err := c.BulkDeleteDecisions(toDelete, false) - c.Log.Debugf("deleted %d decisions", count) - return toDelete, err + return data, nil } -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer) */ decisions := c.Ent.Decision.Query() + for param, value := range filter { switch param { case "contains": @@ -359,48 +369,48 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } else if ip_sz == 16 { if contains { /*decision contains {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), )) } else { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), @@ -410,28 +420,31 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.BulkDeleteDecisions(toDelete, false) + + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } + return strconv.Itoa(count), toDelete, nil } -// SoftDeleteDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC())) + for param, value := range filter { switch param { case "contains": @@ -480,24 +493,24 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri /*decision contains {start_ip,end_ip}*/ if contains { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), @@ -505,24 +518,24 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri } else { /*decision is contained within {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), @@ -531,64 +544,101 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri } else if ip_sz != 0 { return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + + DecisionsToDelete, err := decisions.All(ctx) if err != nil { - c.Log.Warningf("SoftDeleteDecisionsWithFilter : %s", err) - return "0", nil, errors.Wrap(DeleteFail, "soft delete decisions with provided filter") + c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) + return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.BulkDeleteDecisions(DecisionsToDelete, true) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { - return "0", nil, errors.Wrapf(DeleteFail, "soft delete decisions with provided filter : %s", err) + return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } + return strconv.Itoa(count), DecisionsToDelete, err } -// BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them. -// We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI -func (c *Client) BulkDeleteDecisions(decisionsToDelete []*ent.Decision, softDelete bool) (int, error) { - const bulkSize = 256 //scientifically proven to be the best value for bulk delete +func decisionIDs(decisions []*ent.Decision) []int { + ids := make([]int, len(decisions)) + for i, d := range decisions { + ids[i] = d.ID + } - var ( - nbUpdates int - err error - totalUpdates = 0 - ) + return ids +} + +// ExpireDecisions sets the expiration of a list of decisions to now() +// It returns the number of impacted decisions for the CAPI/PAPI +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { + if len(decisions) <= decisionDeleteBulkSize { + ids := decisionIDs(decisions) + + rows, err := c.Ent.Decision.Update().Where( + decision.IDIn(ids...), + ).SetUntil(time.Now().UTC()).Save(ctx) + if err != nil { + return 0, fmt.Errorf("expire decisions with provided filter: %w", err) + } - idsToDelete := make([]int, len(decisionsToDelete)) - for i, decision := range decisionsToDelete { - idsToDelete[i] = decision.ID + return rows, nil } - for _, chunk := range slicetools.Chunks(idsToDelete, bulkSize) { - if softDelete { - nbUpdates, err = c.Ent.Decision.Update().Where( - decision.IDIn(chunk...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) - if err != nil { - return totalUpdates, fmt.Errorf("soft delete decisions with provided filter: %w", err) - } - } else { - nbUpdates, err = c.Ent.Decision.Delete().Where( - decision.IDIn(chunk...), - ).Exec(c.CTX) - if err != nil { - return totalUpdates, fmt.Errorf("hard delete decisions with provided filter: %w", err) - } + // big batch, let's split it and recurse + + total := 0 + + for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { + rows, err := c.ExpireDecisions(ctx, chunk) + if err != nil { + return total, err + } + + total += rows + } + + return total, nil +} + +// DeleteDecisions removes a list of decisions from the database +// It returns the number of impacted decisions for the CAPI/PAPI +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { + if len(decisions) < decisionDeleteBulkSize { + ids := decisionIDs(decisions) + + rows, err := c.Ent.Decision.Delete().Where( + decision.IDIn(ids...), + ).Exec(ctx) + if err != nil { + return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) + } + + return rows, nil + } + + // big batch, let's split it and recurse + + tot := 0 + + for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { + rows, err := c.DeleteDecisions(ctx, chunk) + if err != nil { + return tot, err } - totalUpdates += nbUpdates + + tot += rows } - return totalUpdates, nil + return tot, nil } -// SoftDeleteDecisionByID set the expiration of a decision to now() -func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +// ExpireDecision set the expiration of a decision to now() +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { - c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate)) + c.Log.Warningf("ExpireDecisionByID : %v (nb expired: %d)", err, len(toUpdate)) return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) } @@ -596,28 +646,30 @@ func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, e return 0, nil, ItemNotFound } - count, err := c.BulkDeleteDecisions(toUpdate, true) + count, err := c.ExpireDecisions(ctx, toUpdate) + return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) } contains := true decisions := c.Ent.Decision.Query() + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -625,9 +677,70 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz, count int + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + if err != nil { + return 0, fmt.Errorf("unable to convert '%s' to int: %w", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query() + + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) + } + + decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) + + count, err = decisions.Count(ctx) + if err != nil { + return 0, fmt.Errorf("fail to count decisions: %w", err) + } + + return count, nil +} +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz int + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + if err != nil { + return 0, fmt.Errorf("unable to convert '%s' to int: %w", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query().Where( + decision.UntilGT(time.Now().UTC()), + ) + + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) + } + + decisions = decisions.Order(ent.Desc(decision.FieldUntil)) + + decision, err := decisions.First(ctx) + if err != nil && !ent.IsNotFound(err) { + return 0, fmt.Errorf("fail to get decision: %w", err) + } + + if decision == nil { + return 0, nil + } + + return decision.Until.Sub(time.Now().UTC()), nil +} + +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) } @@ -642,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err := decisions.Count(c.CTX) + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -667,6 +780,7 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz decision.IPSizeEQ(int64(ip_sz)), )) } + return decisions, nil } @@ -674,24 +788,24 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz /*decision contains {start_ip,end_ip}*/ if contains { decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip < query.start_ip + // decision.start_ip < query.start_ip decision.StartIPLT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix + // decision.start_suffix <= query.start_suffix decision.StartSuffixLTE(start_sfx), )), decision.Or( - //decision.end_ip > query.end_ip + // decision.end_ip > query.end_ip decision.EndIPGT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix + // decision.end_suffix >= query.end_suffix decision.EndSuffixGTE(end_sfx), ), ), @@ -699,29 +813,30 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz } else { /*decision is contained within {start_ip,end_ip}*/ decisions = decisions.Where(decision.And( - //matching addr size + // matching addr size decision.IPSizeEQ(int64(ip_sz)), decision.Or( - //decision.start_ip > query.start_ip + // decision.start_ip > query.start_ip decision.StartIPGT(start_ip), decision.And( - //decision.start_ip == query.start_ip + // decision.start_ip == query.start_ip decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix + // decision.start_suffix >= query.start_suffix decision.StartSuffixGTE(start_sfx), )), decision.Or( - //decision.end_ip < query.end_ip + // decision.end_ip < query.end_ip decision.EndIPLT(end_ip), decision.And( - //decision.end_ip == query.end_ip + // decision.end_ip == query.end_ip decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix + // decision.end_suffix <= query.end_suffix decision.EndSuffixLTE(end_sfx), ), ), )) } + return decisions, nil } @@ -735,8 +850,10 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision { words := strings.Split(s, ",") predicates := make([]predicate.Decision, len(words)) + for i, word := range words { predicates[i] = predicateFunc(word) } + return predicates } diff --git a/pkg/database/ent/alert.go b/pkg/database/ent/alert.go index 2649923bf5e..eb0e1cb7612 100644 --- a/pkg/database/ent/alert.go +++ b/pkg/database/ent/alert.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" @@ -18,9 +19,9 @@ type Alert struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Scenario holds the value of the "scenario" field. Scenario string `json:"scenario,omitempty"` // BucketId holds the value of the "bucketId" field. @@ -63,10 +64,13 @@ type Alert struct { Simulated bool `json:"simulated,omitempty"` // UUID holds the value of the "uuid" field. UUID string `json:"uuid,omitempty"` + // Remediation holds the value of the "remediation" field. + Remediation bool `json:"remediation,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the AlertQuery when eager-loading is set. Edges AlertEdges `json:"edges"` machine_alerts *int + selectValues sql.SelectValues } // AlertEdges holds the relations/edges for other nodes in the graph. @@ -87,12 +91,10 @@ type AlertEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e AlertEdges) OwnerOrErr() (*Machine, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. - return nil, &NotFoundError{label: machine.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: machine.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -129,7 +131,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case alert.FieldSimulated: + case alert.FieldSimulated, alert.FieldRemediation: values[i] = new(sql.NullBool) case alert.FieldSourceLatitude, alert.FieldSourceLongitude: values[i] = new(sql.NullFloat64) @@ -142,7 +144,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) { case alert.ForeignKeys[0]: // machine_alerts values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Alert", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -166,15 +168,13 @@ func (a *Alert) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - a.CreatedAt = new(time.Time) - *a.CreatedAt = value.Time + a.CreatedAt = value.Time } case alert.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - a.UpdatedAt = new(time.Time) - *a.UpdatedAt = value.Time + a.UpdatedAt = value.Time } case alert.FieldScenario: if value, ok := values[i].(*sql.NullString); !ok { @@ -302,6 +302,12 @@ func (a *Alert) assignValues(columns []string, values []any) error { } else if value.Valid { a.UUID = value.String } + case alert.FieldRemediation: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field remediation", values[i]) + } else if value.Valid { + a.Remediation = value.Bool + } case alert.ForeignKeys[0]: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for edge-field machine_alerts", value) @@ -309,36 +315,44 @@ func (a *Alert) assignValues(columns []string, values []any) error { a.machine_alerts = new(int) *a.machine_alerts = int(value.Int64) } + default: + a.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Alert. +// This includes values selected through modifiers, order, etc. +func (a *Alert) Value(name string) (ent.Value, error) { + return a.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Alert entity. func (a *Alert) QueryOwner() *MachineQuery { - return (&AlertClient{config: a.config}).QueryOwner(a) + return NewAlertClient(a.config).QueryOwner(a) } // QueryDecisions queries the "decisions" edge of the Alert entity. func (a *Alert) QueryDecisions() *DecisionQuery { - return (&AlertClient{config: a.config}).QueryDecisions(a) + return NewAlertClient(a.config).QueryDecisions(a) } // QueryEvents queries the "events" edge of the Alert entity. func (a *Alert) QueryEvents() *EventQuery { - return (&AlertClient{config: a.config}).QueryEvents(a) + return NewAlertClient(a.config).QueryEvents(a) } // QueryMetas queries the "metas" edge of the Alert entity. func (a *Alert) QueryMetas() *MetaQuery { - return (&AlertClient{config: a.config}).QueryMetas(a) + return NewAlertClient(a.config).QueryMetas(a) } // Update returns a builder for updating this Alert. // Note that you need to call Alert.Unwrap() before calling this method if this Alert // was returned from a transaction, and the transaction was committed or rolled back. func (a *Alert) Update() *AlertUpdateOne { - return (&AlertClient{config: a.config}).UpdateOne(a) + return NewAlertClient(a.config).UpdateOne(a) } // Unwrap unwraps the Alert entity that was returned from a transaction after it was closed, @@ -357,15 +371,11 @@ func (a *Alert) String() string { var builder strings.Builder builder.WriteString("Alert(") builder.WriteString(fmt.Sprintf("id=%v, ", a.ID)) - if v := a.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(a.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := a.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(a.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("scenario=") builder.WriteString(a.Scenario) @@ -429,15 +439,12 @@ func (a *Alert) String() string { builder.WriteString(", ") builder.WriteString("uuid=") builder.WriteString(a.UUID) + builder.WriteString(", ") + builder.WriteString("remediation=") + builder.WriteString(fmt.Sprintf("%v", a.Remediation)) builder.WriteByte(')') return builder.String() } // Alerts is a parsable slice of Alert. type Alerts []*Alert - -func (a Alerts) config(cfg config) { - for _i := range a { - a[_i].config = cfg - } -} diff --git a/pkg/database/ent/alert/alert.go b/pkg/database/ent/alert/alert.go index abee13fb97a..62aade98e87 100644 --- a/pkg/database/ent/alert/alert.go +++ b/pkg/database/ent/alert/alert.go @@ -4,6 +4,9 @@ package alert import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,6 +60,8 @@ const ( FieldSimulated = "simulated" // FieldUUID holds the string denoting the uuid field in the database. FieldUUID = "uuid" + // FieldRemediation holds the string denoting the remediation field in the database. + FieldRemediation = "remediation" // EdgeOwner holds the string denoting the owner edge name in mutations. EdgeOwner = "owner" // EdgeDecisions holds the string denoting the decisions edge name in mutations. @@ -123,6 +128,7 @@ var Columns = []string{ FieldScenarioHash, FieldSimulated, FieldUUID, + FieldRemediation, } // ForeignKeys holds the SQL foreign-keys that are owned by the "alerts" @@ -149,8 +155,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -168,3 +172,208 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Alert queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByBucketId orders the results by the bucketId field. +func ByBucketId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBucketId, opts...).ToFunc() +} + +// ByMessage orders the results by the message field. +func ByMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMessage, opts...).ToFunc() +} + +// ByEventsCountField orders the results by the eventsCount field. +func ByEventsCountField(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEventsCount, opts...).ToFunc() +} + +// ByStartedAt orders the results by the startedAt field. +func ByStartedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartedAt, opts...).ToFunc() +} + +// ByStoppedAt orders the results by the stoppedAt field. +func ByStoppedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoppedAt, opts...).ToFunc() +} + +// BySourceIp orders the results by the sourceIp field. +func BySourceIp(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceIp, opts...).ToFunc() +} + +// BySourceRange orders the results by the sourceRange field. +func BySourceRange(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceRange, opts...).ToFunc() +} + +// BySourceAsNumber orders the results by the sourceAsNumber field. +func BySourceAsNumber(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsNumber, opts...).ToFunc() +} + +// BySourceAsName orders the results by the sourceAsName field. +func BySourceAsName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceAsName, opts...).ToFunc() +} + +// BySourceCountry orders the results by the sourceCountry field. +func BySourceCountry(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceCountry, opts...).ToFunc() +} + +// BySourceLatitude orders the results by the sourceLatitude field. +func BySourceLatitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLatitude, opts...).ToFunc() +} + +// BySourceLongitude orders the results by the sourceLongitude field. +func BySourceLongitude(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceLongitude, opts...).ToFunc() +} + +// BySourceScope orders the results by the sourceScope field. +func BySourceScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceScope, opts...).ToFunc() +} + +// BySourceValue orders the results by the sourceValue field. +func BySourceValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSourceValue, opts...).ToFunc() +} + +// ByCapacity orders the results by the capacity field. +func ByCapacity(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCapacity, opts...).ToFunc() +} + +// ByLeakSpeed orders the results by the leakSpeed field. +func ByLeakSpeed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLeakSpeed, opts...).ToFunc() +} + +// ByScenarioVersion orders the results by the scenarioVersion field. +func ByScenarioVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioVersion, opts...).ToFunc() +} + +// ByScenarioHash orders the results by the scenarioHash field. +func ByScenarioHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarioHash, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByRemediation orders the results by the remediation field. +func ByRemediation(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRemediation, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} + +// ByDecisionsCount orders the results by decisions count. +func ByDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDecisionsStep(), opts...) + } +} + +// ByDecisions orders the results by decisions terms. +func ByDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEventsCount orders the results by events count. +func ByEventsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEventsStep(), opts...) + } +} + +// ByEvents orders the results by events terms. +func ByEvents(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEventsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByMetasCount orders the results by metas count. +func ByMetasCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMetasStep(), opts...) + } +} + +// ByMetas orders the results by metas terms. +func ByMetas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMetasStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} +func newDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), + ) +} +func newEventsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EventsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), + ) +} +func newMetasStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MetasInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), + ) +} diff --git a/pkg/database/ent/alert/where.go b/pkg/database/ent/alert/where.go index ef5b89b615f..da6080fffb9 100644 --- a/pkg/database/ent/alert/where.go +++ b/pkg/database/ent/alert/where.go @@ -12,2440 +12,1617 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Alert(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // BucketId applies equality check predicate on the "bucketId" field. It's identical to BucketIdEQ. func BucketId(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // Message applies equality check predicate on the "message" field. It's identical to MessageEQ. func Message(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // EventsCount applies equality check predicate on the "eventsCount" field. It's identical to EventsCountEQ. func EventsCount(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // StartedAt applies equality check predicate on the "startedAt" field. It's identical to StartedAtEQ. func StartedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StoppedAt applies equality check predicate on the "stoppedAt" field. It's identical to StoppedAtEQ. func StoppedAt(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // SourceIp applies equality check predicate on the "sourceIp" field. It's identical to SourceIpEQ. func SourceIp(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceRange applies equality check predicate on the "sourceRange" field. It's identical to SourceRangeEQ. func SourceRange(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceAsNumber applies equality check predicate on the "sourceAsNumber" field. It's identical to SourceAsNumberEQ. func SourceAsNumber(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsName applies equality check predicate on the "sourceAsName" field. It's identical to SourceAsNameEQ. func SourceAsName(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceCountry applies equality check predicate on the "sourceCountry" field. It's identical to SourceCountryEQ. func SourceCountry(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceLatitude applies equality check predicate on the "sourceLatitude" field. It's identical to SourceLatitudeEQ. func SourceLatitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v)) } // SourceLongitude applies equality check predicate on the "sourceLongitude" field. It's identical to SourceLongitudeEQ. func SourceLongitude(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v)) } // SourceScope applies equality check predicate on the "sourceScope" field. It's identical to SourceScopeEQ. func SourceScope(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceValue applies equality check predicate on the "sourceValue" field. It's identical to SourceValueEQ. func SourceValue(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // Capacity applies equality check predicate on the "capacity" field. It's identical to CapacityEQ. func Capacity(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // LeakSpeed applies equality check predicate on the "leakSpeed" field. It's identical to LeakSpeedEQ. func LeakSpeed(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // ScenarioVersion applies equality check predicate on the "scenarioVersion" field. It's identical to ScenarioVersionEQ. func ScenarioVersion(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioHash applies equality check predicate on the "scenarioHash" field. It's identical to ScenarioHashEQ. func ScenarioHash(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. func UUID(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) +} + +// Remediation applies equality check predicate on the "remediation" field. It's identical to RemediationEQ. +func Remediation(v bool) predicate.Alert { + return predicate.Alert(sql.FieldEQ(FieldRemediation, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Alert(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Alert(sql.FieldLTE(FieldUpdatedAt, v)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. func ScenarioNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. func ScenarioGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. func ScenarioHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenario, v)) } // BucketIdEQ applies the EQ predicate on the "bucketId" field. func BucketIdEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldBucketId, v)) } // BucketIdNEQ applies the NEQ predicate on the "bucketId" field. func BucketIdNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldBucketId, v)) } // BucketIdIn applies the In predicate on the "bucketId" field. func BucketIdIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldBucketId, vs...)) } // BucketIdNotIn applies the NotIn predicate on the "bucketId" field. func BucketIdNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBucketId), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldBucketId, vs...)) } // BucketIdGT applies the GT predicate on the "bucketId" field. func BucketIdGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGT(FieldBucketId, v)) } // BucketIdGTE applies the GTE predicate on the "bucketId" field. func BucketIdGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldBucketId, v)) } // BucketIdLT applies the LT predicate on the "bucketId" field. func BucketIdLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLT(FieldBucketId, v)) } // BucketIdLTE applies the LTE predicate on the "bucketId" field. func BucketIdLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldBucketId, v)) } // BucketIdContains applies the Contains predicate on the "bucketId" field. func BucketIdContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContains(FieldBucketId, v)) } // BucketIdHasPrefix applies the HasPrefix predicate on the "bucketId" field. func BucketIdHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldBucketId, v)) } // BucketIdHasSuffix applies the HasSuffix predicate on the "bucketId" field. func BucketIdHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldBucketId, v)) } // BucketIdIsNil applies the IsNil predicate on the "bucketId" field. func BucketIdIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldIsNull(FieldBucketId)) } // BucketIdNotNil applies the NotNil predicate on the "bucketId" field. func BucketIdNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldBucketId))) - }) + return predicate.Alert(sql.FieldNotNull(FieldBucketId)) } // BucketIdEqualFold applies the EqualFold predicate on the "bucketId" field. func BucketIdEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldBucketId, v)) } // BucketIdContainsFold applies the ContainsFold predicate on the "bucketId" field. func BucketIdContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBucketId), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldBucketId, v)) } // MessageEQ applies the EQ predicate on the "message" field. func MessageEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldMessage, v)) } // MessageNEQ applies the NEQ predicate on the "message" field. func MessageNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldMessage, v)) } // MessageIn applies the In predicate on the "message" field. func MessageIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldMessage, vs...)) } // MessageNotIn applies the NotIn predicate on the "message" field. func MessageNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMessage), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldMessage, vs...)) } // MessageGT applies the GT predicate on the "message" field. func MessageGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGT(FieldMessage, v)) } // MessageGTE applies the GTE predicate on the "message" field. func MessageGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldMessage, v)) } // MessageLT applies the LT predicate on the "message" field. func MessageLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLT(FieldMessage, v)) } // MessageLTE applies the LTE predicate on the "message" field. func MessageLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldMessage, v)) } // MessageContains applies the Contains predicate on the "message" field. func MessageContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContains(FieldMessage, v)) } // MessageHasPrefix applies the HasPrefix predicate on the "message" field. func MessageHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldMessage, v)) } // MessageHasSuffix applies the HasSuffix predicate on the "message" field. func MessageHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldMessage, v)) } // MessageIsNil applies the IsNil predicate on the "message" field. func MessageIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldIsNull(FieldMessage)) } // MessageNotNil applies the NotNil predicate on the "message" field. func MessageNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldMessage))) - }) + return predicate.Alert(sql.FieldNotNull(FieldMessage)) } // MessageEqualFold applies the EqualFold predicate on the "message" field. func MessageEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldMessage, v)) } // MessageContainsFold applies the ContainsFold predicate on the "message" field. func MessageContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMessage), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldMessage, v)) } // EventsCountEQ applies the EQ predicate on the "eventsCount" field. func EventsCountEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldEventsCount, v)) } // EventsCountNEQ applies the NEQ predicate on the "eventsCount" field. func EventsCountNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldEventsCount, v)) } // EventsCountIn applies the In predicate on the "eventsCount" field. func EventsCountIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldEventsCount, vs...)) } // EventsCountNotIn applies the NotIn predicate on the "eventsCount" field. func EventsCountNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEventsCount), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldEventsCount, vs...)) } // EventsCountGT applies the GT predicate on the "eventsCount" field. func EventsCountGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGT(FieldEventsCount, v)) } // EventsCountGTE applies the GTE predicate on the "eventsCount" field. func EventsCountGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldEventsCount, v)) } // EventsCountLT applies the LT predicate on the "eventsCount" field. func EventsCountLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLT(FieldEventsCount, v)) } // EventsCountLTE applies the LTE predicate on the "eventsCount" field. func EventsCountLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEventsCount), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldEventsCount, v)) } // EventsCountIsNil applies the IsNil predicate on the "eventsCount" field. func EventsCountIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldIsNull(FieldEventsCount)) } // EventsCountNotNil applies the NotNil predicate on the "eventsCount" field. func EventsCountNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEventsCount))) - }) + return predicate.Alert(sql.FieldNotNull(FieldEventsCount)) } // StartedAtEQ applies the EQ predicate on the "startedAt" field. func StartedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStartedAt, v)) } // StartedAtNEQ applies the NEQ predicate on the "startedAt" field. func StartedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStartedAt, v)) } // StartedAtIn applies the In predicate on the "startedAt" field. func StartedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStartedAt, vs...)) } // StartedAtNotIn applies the NotIn predicate on the "startedAt" field. func StartedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStartedAt, vs...)) } // StartedAtGT applies the GT predicate on the "startedAt" field. func StartedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStartedAt, v)) } // StartedAtGTE applies the GTE predicate on the "startedAt" field. func StartedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStartedAt, v)) } // StartedAtLT applies the LT predicate on the "startedAt" field. func StartedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStartedAt, v)) } // StartedAtLTE applies the LTE predicate on the "startedAt" field. func StartedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStartedAt, v)) } // StartedAtIsNil applies the IsNil predicate on the "startedAt" field. func StartedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStartedAt)) } // StartedAtNotNil applies the NotNil predicate on the "startedAt" field. func StartedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStartedAt)) } // StoppedAtEQ applies the EQ predicate on the "stoppedAt" field. func StoppedAtEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldStoppedAt, v)) } // StoppedAtNEQ applies the NEQ predicate on the "stoppedAt" field. func StoppedAtNEQ(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldStoppedAt, v)) } // StoppedAtIn applies the In predicate on the "stoppedAt" field. func StoppedAtIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldStoppedAt, vs...)) } // StoppedAtNotIn applies the NotIn predicate on the "stoppedAt" field. func StoppedAtNotIn(vs ...time.Time) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStoppedAt), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldStoppedAt, vs...)) } // StoppedAtGT applies the GT predicate on the "stoppedAt" field. func StoppedAtGT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGT(FieldStoppedAt, v)) } // StoppedAtGTE applies the GTE predicate on the "stoppedAt" field. func StoppedAtGTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldStoppedAt, v)) } // StoppedAtLT applies the LT predicate on the "stoppedAt" field. func StoppedAtLT(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLT(FieldStoppedAt, v)) } // StoppedAtLTE applies the LTE predicate on the "stoppedAt" field. func StoppedAtLTE(v time.Time) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStoppedAt), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldStoppedAt, v)) } // StoppedAtIsNil applies the IsNil predicate on the "stoppedAt" field. func StoppedAtIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldIsNull(FieldStoppedAt)) } // StoppedAtNotNil applies the NotNil predicate on the "stoppedAt" field. func StoppedAtNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStoppedAt))) - }) + return predicate.Alert(sql.FieldNotNull(FieldStoppedAt)) } // SourceIpEQ applies the EQ predicate on the "sourceIp" field. func SourceIpEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceIp, v)) } // SourceIpNEQ applies the NEQ predicate on the "sourceIp" field. func SourceIpNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceIp, v)) } // SourceIpIn applies the In predicate on the "sourceIp" field. func SourceIpIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceIp, vs...)) } // SourceIpNotIn applies the NotIn predicate on the "sourceIp" field. func SourceIpNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceIp), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceIp, vs...)) } // SourceIpGT applies the GT predicate on the "sourceIp" field. func SourceIpGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceIp, v)) } // SourceIpGTE applies the GTE predicate on the "sourceIp" field. func SourceIpGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceIp, v)) } // SourceIpLT applies the LT predicate on the "sourceIp" field. func SourceIpLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceIp, v)) } // SourceIpLTE applies the LTE predicate on the "sourceIp" field. func SourceIpLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceIp, v)) } // SourceIpContains applies the Contains predicate on the "sourceIp" field. func SourceIpContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceIp, v)) } // SourceIpHasPrefix applies the HasPrefix predicate on the "sourceIp" field. func SourceIpHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceIp, v)) } // SourceIpHasSuffix applies the HasSuffix predicate on the "sourceIp" field. func SourceIpHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceIp, v)) } // SourceIpIsNil applies the IsNil predicate on the "sourceIp" field. func SourceIpIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceIp)) } // SourceIpNotNil applies the NotNil predicate on the "sourceIp" field. func SourceIpNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceIp))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceIp)) } // SourceIpEqualFold applies the EqualFold predicate on the "sourceIp" field. func SourceIpEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceIp, v)) } // SourceIpContainsFold applies the ContainsFold predicate on the "sourceIp" field. func SourceIpContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceIp), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceIp, v)) } // SourceRangeEQ applies the EQ predicate on the "sourceRange" field. func SourceRangeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceRange, v)) } // SourceRangeNEQ applies the NEQ predicate on the "sourceRange" field. func SourceRangeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceRange, v)) } // SourceRangeIn applies the In predicate on the "sourceRange" field. func SourceRangeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceRange, vs...)) } // SourceRangeNotIn applies the NotIn predicate on the "sourceRange" field. func SourceRangeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceRange), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceRange, vs...)) } // SourceRangeGT applies the GT predicate on the "sourceRange" field. func SourceRangeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceRange, v)) } // SourceRangeGTE applies the GTE predicate on the "sourceRange" field. func SourceRangeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceRange, v)) } // SourceRangeLT applies the LT predicate on the "sourceRange" field. func SourceRangeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceRange, v)) } // SourceRangeLTE applies the LTE predicate on the "sourceRange" field. func SourceRangeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceRange, v)) } // SourceRangeContains applies the Contains predicate on the "sourceRange" field. func SourceRangeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceRange, v)) } // SourceRangeHasPrefix applies the HasPrefix predicate on the "sourceRange" field. func SourceRangeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceRange, v)) } // SourceRangeHasSuffix applies the HasSuffix predicate on the "sourceRange" field. func SourceRangeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceRange, v)) } // SourceRangeIsNil applies the IsNil predicate on the "sourceRange" field. func SourceRangeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceRange)) } // SourceRangeNotNil applies the NotNil predicate on the "sourceRange" field. func SourceRangeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceRange))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceRange)) } // SourceRangeEqualFold applies the EqualFold predicate on the "sourceRange" field. func SourceRangeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceRange, v)) } // SourceRangeContainsFold applies the ContainsFold predicate on the "sourceRange" field. func SourceRangeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceRange), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceRange, v)) } // SourceAsNumberEQ applies the EQ predicate on the "sourceAsNumber" field. func SourceAsNumberEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsNumber, v)) } // SourceAsNumberNEQ applies the NEQ predicate on the "sourceAsNumber" field. func SourceAsNumberNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsNumber, v)) } // SourceAsNumberIn applies the In predicate on the "sourceAsNumber" field. func SourceAsNumberIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberNotIn applies the NotIn predicate on the "sourceAsNumber" field. func SourceAsNumberNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsNumber), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsNumber, vs...)) } // SourceAsNumberGT applies the GT predicate on the "sourceAsNumber" field. func SourceAsNumberGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsNumber, v)) } // SourceAsNumberGTE applies the GTE predicate on the "sourceAsNumber" field. func SourceAsNumberGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsNumber, v)) } // SourceAsNumberLT applies the LT predicate on the "sourceAsNumber" field. func SourceAsNumberLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsNumber, v)) } // SourceAsNumberLTE applies the LTE predicate on the "sourceAsNumber" field. func SourceAsNumberLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsNumber, v)) } // SourceAsNumberContains applies the Contains predicate on the "sourceAsNumber" field. func SourceAsNumberContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsNumber, v)) } // SourceAsNumberHasPrefix applies the HasPrefix predicate on the "sourceAsNumber" field. func SourceAsNumberHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsNumber, v)) } // SourceAsNumberHasSuffix applies the HasSuffix predicate on the "sourceAsNumber" field. func SourceAsNumberHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsNumber, v)) } // SourceAsNumberIsNil applies the IsNil predicate on the "sourceAsNumber" field. func SourceAsNumberIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsNumber)) } // SourceAsNumberNotNil applies the NotNil predicate on the "sourceAsNumber" field. func SourceAsNumberNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsNumber))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsNumber)) } // SourceAsNumberEqualFold applies the EqualFold predicate on the "sourceAsNumber" field. func SourceAsNumberEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsNumber, v)) } // SourceAsNumberContainsFold applies the ContainsFold predicate on the "sourceAsNumber" field. func SourceAsNumberContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsNumber), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsNumber, v)) } // SourceAsNameEQ applies the EQ predicate on the "sourceAsName" field. func SourceAsNameEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceAsName, v)) } // SourceAsNameNEQ applies the NEQ predicate on the "sourceAsName" field. func SourceAsNameNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceAsName, v)) } // SourceAsNameIn applies the In predicate on the "sourceAsName" field. func SourceAsNameIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceAsName, vs...)) } // SourceAsNameNotIn applies the NotIn predicate on the "sourceAsName" field. func SourceAsNameNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceAsName), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceAsName, vs...)) } // SourceAsNameGT applies the GT predicate on the "sourceAsName" field. func SourceAsNameGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceAsName, v)) } // SourceAsNameGTE applies the GTE predicate on the "sourceAsName" field. func SourceAsNameGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceAsName, v)) } // SourceAsNameLT applies the LT predicate on the "sourceAsName" field. func SourceAsNameLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceAsName, v)) } // SourceAsNameLTE applies the LTE predicate on the "sourceAsName" field. func SourceAsNameLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceAsName, v)) } // SourceAsNameContains applies the Contains predicate on the "sourceAsName" field. func SourceAsNameContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceAsName, v)) } // SourceAsNameHasPrefix applies the HasPrefix predicate on the "sourceAsName" field. func SourceAsNameHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceAsName, v)) } // SourceAsNameHasSuffix applies the HasSuffix predicate on the "sourceAsName" field. func SourceAsNameHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceAsName, v)) } // SourceAsNameIsNil applies the IsNil predicate on the "sourceAsName" field. func SourceAsNameIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceAsName)) } // SourceAsNameNotNil applies the NotNil predicate on the "sourceAsName" field. func SourceAsNameNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceAsName))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceAsName)) } // SourceAsNameEqualFold applies the EqualFold predicate on the "sourceAsName" field. func SourceAsNameEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceAsName, v)) } // SourceAsNameContainsFold applies the ContainsFold predicate on the "sourceAsName" field. func SourceAsNameContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceAsName), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceAsName, v)) } // SourceCountryEQ applies the EQ predicate on the "sourceCountry" field. func SourceCountryEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceCountry, v)) } // SourceCountryNEQ applies the NEQ predicate on the "sourceCountry" field. func SourceCountryNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceCountry, v)) } // SourceCountryIn applies the In predicate on the "sourceCountry" field. func SourceCountryIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceCountry, vs...)) } // SourceCountryNotIn applies the NotIn predicate on the "sourceCountry" field. func SourceCountryNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceCountry), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceCountry, vs...)) } // SourceCountryGT applies the GT predicate on the "sourceCountry" field. func SourceCountryGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceCountry, v)) } // SourceCountryGTE applies the GTE predicate on the "sourceCountry" field. func SourceCountryGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceCountry, v)) } // SourceCountryLT applies the LT predicate on the "sourceCountry" field. func SourceCountryLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceCountry, v)) } // SourceCountryLTE applies the LTE predicate on the "sourceCountry" field. func SourceCountryLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceCountry, v)) } // SourceCountryContains applies the Contains predicate on the "sourceCountry" field. func SourceCountryContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceCountry, v)) } // SourceCountryHasPrefix applies the HasPrefix predicate on the "sourceCountry" field. func SourceCountryHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceCountry, v)) } // SourceCountryHasSuffix applies the HasSuffix predicate on the "sourceCountry" field. func SourceCountryHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceCountry, v)) } // SourceCountryIsNil applies the IsNil predicate on the "sourceCountry" field. func SourceCountryIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceCountry)) } // SourceCountryNotNil applies the NotNil predicate on the "sourceCountry" field. func SourceCountryNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceCountry))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceCountry)) } // SourceCountryEqualFold applies the EqualFold predicate on the "sourceCountry" field. func SourceCountryEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceCountry, v)) } // SourceCountryContainsFold applies the ContainsFold predicate on the "sourceCountry" field. func SourceCountryContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceCountry), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceCountry, v)) } // SourceLatitudeEQ applies the EQ predicate on the "sourceLatitude" field. func SourceLatitudeEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLatitude, v)) } // SourceLatitudeNEQ applies the NEQ predicate on the "sourceLatitude" field. func SourceLatitudeNEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceLatitude, v)) } // SourceLatitudeIn applies the In predicate on the "sourceLatitude" field. func SourceLatitudeIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceLatitude), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceLatitude, vs...)) } // SourceLatitudeNotIn applies the NotIn predicate on the "sourceLatitude" field. func SourceLatitudeNotIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceLatitude), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceLatitude, vs...)) } // SourceLatitudeGT applies the GT predicate on the "sourceLatitude" field. func SourceLatitudeGT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceLatitude, v)) } // SourceLatitudeGTE applies the GTE predicate on the "sourceLatitude" field. func SourceLatitudeGTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceLatitude, v)) } // SourceLatitudeLT applies the LT predicate on the "sourceLatitude" field. func SourceLatitudeLT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceLatitude, v)) } // SourceLatitudeLTE applies the LTE predicate on the "sourceLatitude" field. func SourceLatitudeLTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceLatitude), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceLatitude, v)) } // SourceLatitudeIsNil applies the IsNil predicate on the "sourceLatitude" field. func SourceLatitudeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceLatitude))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceLatitude)) } // SourceLatitudeNotNil applies the NotNil predicate on the "sourceLatitude" field. -func SourceLatitudeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceLatitude))) - }) +func SourceLatitudeNotNil() predicate.Alert { + return predicate.Alert(sql.FieldNotNull(FieldSourceLatitude)) } // SourceLongitudeEQ applies the EQ predicate on the "sourceLongitude" field. func SourceLongitudeEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceLongitude, v)) } // SourceLongitudeNEQ applies the NEQ predicate on the "sourceLongitude" field. func SourceLongitudeNEQ(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceLongitude, v)) } // SourceLongitudeIn applies the In predicate on the "sourceLongitude" field. func SourceLongitudeIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceLongitude, vs...)) } // SourceLongitudeNotIn applies the NotIn predicate on the "sourceLongitude" field. func SourceLongitudeNotIn(vs ...float32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceLongitude), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceLongitude, vs...)) } // SourceLongitudeGT applies the GT predicate on the "sourceLongitude" field. func SourceLongitudeGT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceLongitude, v)) } // SourceLongitudeGTE applies the GTE predicate on the "sourceLongitude" field. func SourceLongitudeGTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceLongitude, v)) } // SourceLongitudeLT applies the LT predicate on the "sourceLongitude" field. func SourceLongitudeLT(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceLongitude, v)) } // SourceLongitudeLTE applies the LTE predicate on the "sourceLongitude" field. func SourceLongitudeLTE(v float32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceLongitude), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceLongitude, v)) } // SourceLongitudeIsNil applies the IsNil predicate on the "sourceLongitude" field. func SourceLongitudeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceLongitude)) } // SourceLongitudeNotNil applies the NotNil predicate on the "sourceLongitude" field. func SourceLongitudeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceLongitude))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceLongitude)) } // SourceScopeEQ applies the EQ predicate on the "sourceScope" field. func SourceScopeEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceScope, v)) } // SourceScopeNEQ applies the NEQ predicate on the "sourceScope" field. func SourceScopeNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceScope, v)) } // SourceScopeIn applies the In predicate on the "sourceScope" field. func SourceScopeIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceScope, vs...)) } // SourceScopeNotIn applies the NotIn predicate on the "sourceScope" field. func SourceScopeNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceScope), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceScope, vs...)) } // SourceScopeGT applies the GT predicate on the "sourceScope" field. func SourceScopeGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceScope, v)) } // SourceScopeGTE applies the GTE predicate on the "sourceScope" field. func SourceScopeGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceScope, v)) } // SourceScopeLT applies the LT predicate on the "sourceScope" field. func SourceScopeLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceScope, v)) } // SourceScopeLTE applies the LTE predicate on the "sourceScope" field. func SourceScopeLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceScope, v)) } // SourceScopeContains applies the Contains predicate on the "sourceScope" field. func SourceScopeContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceScope, v)) } // SourceScopeHasPrefix applies the HasPrefix predicate on the "sourceScope" field. func SourceScopeHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceScope, v)) } // SourceScopeHasSuffix applies the HasSuffix predicate on the "sourceScope" field. func SourceScopeHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceScope, v)) } // SourceScopeIsNil applies the IsNil predicate on the "sourceScope" field. func SourceScopeIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceScope)) } // SourceScopeNotNil applies the NotNil predicate on the "sourceScope" field. func SourceScopeNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceScope))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceScope)) } // SourceScopeEqualFold applies the EqualFold predicate on the "sourceScope" field. func SourceScopeEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceScope, v)) } // SourceScopeContainsFold applies the ContainsFold predicate on the "sourceScope" field. func SourceScopeContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceScope), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceScope, v)) } // SourceValueEQ applies the EQ predicate on the "sourceValue" field. func SourceValueEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSourceValue, v)) } // SourceValueNEQ applies the NEQ predicate on the "sourceValue" field. func SourceValueNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSourceValue, v)) } // SourceValueIn applies the In predicate on the "sourceValue" field. func SourceValueIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldSourceValue, vs...)) } // SourceValueNotIn applies the NotIn predicate on the "sourceValue" field. func SourceValueNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSourceValue), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldSourceValue, vs...)) } // SourceValueGT applies the GT predicate on the "sourceValue" field. func SourceValueGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGT(FieldSourceValue, v)) } // SourceValueGTE applies the GTE predicate on the "sourceValue" field. func SourceValueGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldSourceValue, v)) } // SourceValueLT applies the LT predicate on the "sourceValue" field. func SourceValueLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLT(FieldSourceValue, v)) } // SourceValueLTE applies the LTE predicate on the "sourceValue" field. func SourceValueLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldSourceValue, v)) } // SourceValueContains applies the Contains predicate on the "sourceValue" field. func SourceValueContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContains(FieldSourceValue, v)) } // SourceValueHasPrefix applies the HasPrefix predicate on the "sourceValue" field. func SourceValueHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldSourceValue, v)) } // SourceValueHasSuffix applies the HasSuffix predicate on the "sourceValue" field. func SourceValueHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldSourceValue, v)) } // SourceValueIsNil applies the IsNil predicate on the "sourceValue" field. func SourceValueIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldIsNull(FieldSourceValue)) } // SourceValueNotNil applies the NotNil predicate on the "sourceValue" field. func SourceValueNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldSourceValue))) - }) + return predicate.Alert(sql.FieldNotNull(FieldSourceValue)) } // SourceValueEqualFold applies the EqualFold predicate on the "sourceValue" field. func SourceValueEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldSourceValue, v)) } // SourceValueContainsFold applies the ContainsFold predicate on the "sourceValue" field. func SourceValueContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSourceValue), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldSourceValue, v)) } // CapacityEQ applies the EQ predicate on the "capacity" field. func CapacityEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldCapacity, v)) } // CapacityNEQ applies the NEQ predicate on the "capacity" field. func CapacityNEQ(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldCapacity, v)) } // CapacityIn applies the In predicate on the "capacity" field. func CapacityIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldCapacity, vs...)) } // CapacityNotIn applies the NotIn predicate on the "capacity" field. func CapacityNotIn(vs ...int32) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCapacity), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldCapacity, vs...)) } // CapacityGT applies the GT predicate on the "capacity" field. func CapacityGT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGT(FieldCapacity, v)) } // CapacityGTE applies the GTE predicate on the "capacity" field. func CapacityGTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldCapacity, v)) } // CapacityLT applies the LT predicate on the "capacity" field. func CapacityLT(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLT(FieldCapacity, v)) } // CapacityLTE applies the LTE predicate on the "capacity" field. func CapacityLTE(v int32) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCapacity), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldCapacity, v)) } // CapacityIsNil applies the IsNil predicate on the "capacity" field. func CapacityIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldIsNull(FieldCapacity)) } // CapacityNotNil applies the NotNil predicate on the "capacity" field. func CapacityNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCapacity))) - }) + return predicate.Alert(sql.FieldNotNull(FieldCapacity)) } // LeakSpeedEQ applies the EQ predicate on the "leakSpeed" field. func LeakSpeedEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldLeakSpeed, v)) } // LeakSpeedNEQ applies the NEQ predicate on the "leakSpeed" field. func LeakSpeedNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldLeakSpeed, v)) } // LeakSpeedIn applies the In predicate on the "leakSpeed" field. func LeakSpeedIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldLeakSpeed, vs...)) } // LeakSpeedNotIn applies the NotIn predicate on the "leakSpeed" field. func LeakSpeedNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLeakSpeed), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldLeakSpeed, vs...)) } // LeakSpeedGT applies the GT predicate on the "leakSpeed" field. func LeakSpeedGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGT(FieldLeakSpeed, v)) } // LeakSpeedGTE applies the GTE predicate on the "leakSpeed" field. func LeakSpeedGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldLeakSpeed, v)) } // LeakSpeedLT applies the LT predicate on the "leakSpeed" field. func LeakSpeedLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLT(FieldLeakSpeed, v)) } // LeakSpeedLTE applies the LTE predicate on the "leakSpeed" field. func LeakSpeedLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldLeakSpeed, v)) } // LeakSpeedContains applies the Contains predicate on the "leakSpeed" field. func LeakSpeedContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContains(FieldLeakSpeed, v)) } // LeakSpeedHasPrefix applies the HasPrefix predicate on the "leakSpeed" field. func LeakSpeedHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldLeakSpeed, v)) } // LeakSpeedHasSuffix applies the HasSuffix predicate on the "leakSpeed" field. func LeakSpeedHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldLeakSpeed, v)) } // LeakSpeedIsNil applies the IsNil predicate on the "leakSpeed" field. func LeakSpeedIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldIsNull(FieldLeakSpeed)) } // LeakSpeedNotNil applies the NotNil predicate on the "leakSpeed" field. func LeakSpeedNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLeakSpeed))) - }) + return predicate.Alert(sql.FieldNotNull(FieldLeakSpeed)) } // LeakSpeedEqualFold applies the EqualFold predicate on the "leakSpeed" field. func LeakSpeedEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldLeakSpeed, v)) } // LeakSpeedContainsFold applies the ContainsFold predicate on the "leakSpeed" field. func LeakSpeedContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldLeakSpeed), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldLeakSpeed, v)) } // ScenarioVersionEQ applies the EQ predicate on the "scenarioVersion" field. func ScenarioVersionEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioVersion, v)) } // ScenarioVersionNEQ applies the NEQ predicate on the "scenarioVersion" field. func ScenarioVersionNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioVersion, v)) } // ScenarioVersionIn applies the In predicate on the "scenarioVersion" field. func ScenarioVersionIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioVersion, vs...)) } // ScenarioVersionNotIn applies the NotIn predicate on the "scenarioVersion" field. func ScenarioVersionNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioVersion), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioVersion, vs...)) } // ScenarioVersionGT applies the GT predicate on the "scenarioVersion" field. func ScenarioVersionGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioVersion, v)) } // ScenarioVersionGTE applies the GTE predicate on the "scenarioVersion" field. func ScenarioVersionGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioVersion, v)) } // ScenarioVersionLT applies the LT predicate on the "scenarioVersion" field. func ScenarioVersionLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioVersion, v)) } // ScenarioVersionLTE applies the LTE predicate on the "scenarioVersion" field. func ScenarioVersionLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioVersion, v)) } // ScenarioVersionContains applies the Contains predicate on the "scenarioVersion" field. func ScenarioVersionContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioVersion, v)) } // ScenarioVersionHasPrefix applies the HasPrefix predicate on the "scenarioVersion" field. func ScenarioVersionHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioVersion, v)) } // ScenarioVersionHasSuffix applies the HasSuffix predicate on the "scenarioVersion" field. func ScenarioVersionHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioVersion, v)) } // ScenarioVersionIsNil applies the IsNil predicate on the "scenarioVersion" field. func ScenarioVersionIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioVersion)) } // ScenarioVersionNotNil applies the NotNil predicate on the "scenarioVersion" field. func ScenarioVersionNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioVersion))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioVersion)) } // ScenarioVersionEqualFold applies the EqualFold predicate on the "scenarioVersion" field. func ScenarioVersionEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioVersion, v)) } // ScenarioVersionContainsFold applies the ContainsFold predicate on the "scenarioVersion" field. func ScenarioVersionContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioVersion), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioVersion, v)) } // ScenarioHashEQ applies the EQ predicate on the "scenarioHash" field. func ScenarioHashEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldScenarioHash, v)) } // ScenarioHashNEQ applies the NEQ predicate on the "scenarioHash" field. func ScenarioHashNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldScenarioHash, v)) } // ScenarioHashIn applies the In predicate on the "scenarioHash" field. func ScenarioHashIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldScenarioHash, vs...)) } // ScenarioHashNotIn applies the NotIn predicate on the "scenarioHash" field. func ScenarioHashNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarioHash), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldScenarioHash, vs...)) } // ScenarioHashGT applies the GT predicate on the "scenarioHash" field. func ScenarioHashGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGT(FieldScenarioHash, v)) } // ScenarioHashGTE applies the GTE predicate on the "scenarioHash" field. func ScenarioHashGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldScenarioHash, v)) } // ScenarioHashLT applies the LT predicate on the "scenarioHash" field. func ScenarioHashLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLT(FieldScenarioHash, v)) } // ScenarioHashLTE applies the LTE predicate on the "scenarioHash" field. func ScenarioHashLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldScenarioHash, v)) } // ScenarioHashContains applies the Contains predicate on the "scenarioHash" field. func ScenarioHashContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContains(FieldScenarioHash, v)) } // ScenarioHashHasPrefix applies the HasPrefix predicate on the "scenarioHash" field. func ScenarioHashHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldScenarioHash, v)) } // ScenarioHashHasSuffix applies the HasSuffix predicate on the "scenarioHash" field. func ScenarioHashHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldScenarioHash, v)) } // ScenarioHashIsNil applies the IsNil predicate on the "scenarioHash" field. func ScenarioHashIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldIsNull(FieldScenarioHash)) } // ScenarioHashNotNil applies the NotNil predicate on the "scenarioHash" field. func ScenarioHashNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarioHash))) - }) + return predicate.Alert(sql.FieldNotNull(FieldScenarioHash)) } // ScenarioHashEqualFold applies the EqualFold predicate on the "scenarioHash" field. func ScenarioHashEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldScenarioHash, v)) } // ScenarioHashContainsFold applies the ContainsFold predicate on the "scenarioHash" field. func ScenarioHashContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarioHash), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldScenarioHash, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. func UUIDEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Alert { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Alert(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. func UUIDLT(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. func UUIDContains(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. func UUIDNotNil() predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Alert(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. func UUIDEqualFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Alert(sql.FieldContainsFold(FieldUUID, v)) +} + +// RemediationEQ applies the EQ predicate on the "remediation" field. +func RemediationEQ(v bool) predicate.Alert { + return predicate.Alert(sql.FieldEQ(FieldRemediation, v)) +} + +// RemediationNEQ applies the NEQ predicate on the "remediation" field. +func RemediationNEQ(v bool) predicate.Alert { + return predicate.Alert(sql.FieldNEQ(FieldRemediation, v)) +} + +// RemediationIsNil applies the IsNil predicate on the "remediation" field. +func RemediationIsNil() predicate.Alert { + return predicate.Alert(sql.FieldIsNull(FieldRemediation)) +} + +// RemediationNotNil applies the NotNil predicate on the "remediation" field. +func RemediationNotNil() predicate.Alert { + return predicate.Alert(sql.FieldNotNull(FieldRemediation)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -2453,7 +1630,6 @@ func HasOwner() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2463,11 +1639,7 @@ func HasOwner() predicate.Alert { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Machine) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2481,7 +1653,6 @@ func HasDecisions() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2491,11 +1662,7 @@ func HasDecisions() predicate.Alert { // HasDecisionsWith applies the HasEdge predicate on the "decisions" edge with a given conditions (other predicates). func HasDecisionsWith(preds ...predicate.Decision) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DecisionsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn), - ) + step := newDecisionsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2509,7 +1676,6 @@ func HasEvents() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2519,11 +1685,7 @@ func HasEvents() predicate.Alert { // HasEventsWith applies the HasEdge predicate on the "events" edge with a given conditions (other predicates). func HasEventsWith(preds ...predicate.Event) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(EventsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn), - ) + step := newEventsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2537,7 +1699,6 @@ func HasMetas() predicate.Alert { return predicate.Alert(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), ) sqlgraph.HasNeighbors(s, step) @@ -2547,11 +1708,7 @@ func HasMetas() predicate.Alert { // HasMetasWith applies the HasEdge predicate on the "metas" edge with a given conditions (other predicates). func HasMetasWith(preds ...predicate.Meta) predicate.Alert { return predicate.Alert(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(MetasInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn), - ) + step := newMetasStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -2562,32 +1719,15 @@ func HasMetasWith(preds ...predicate.Meta) predicate.Alert { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Alert(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Alert) predicate.Alert { - return predicate.Alert(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Alert(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/alert_create.go b/pkg/database/ent/alert_create.go index 42da5b137ba..753183a9eb9 100644 --- a/pkg/database/ent/alert_create.go +++ b/pkg/database/ent/alert_create.go @@ -338,6 +338,20 @@ func (ac *AlertCreate) SetNillableUUID(s *string) *AlertCreate { return ac } +// SetRemediation sets the "remediation" field. +func (ac *AlertCreate) SetRemediation(b bool) *AlertCreate { + ac.mutation.SetRemediation(b) + return ac +} + +// SetNillableRemediation sets the "remediation" field if the given value is not nil. +func (ac *AlertCreate) SetNillableRemediation(b *bool) *AlertCreate { + if b != nil { + ac.SetRemediation(*b) + } + return ac +} + // SetOwnerID sets the "owner" edge to the Machine entity by ID. func (ac *AlertCreate) SetOwnerID(id int) *AlertCreate { ac.mutation.SetOwnerID(id) @@ -409,50 +423,8 @@ func (ac *AlertCreate) Mutation() *AlertMutation { // Save creates the Alert in the database. func (ac *AlertCreate) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) ac.defaults() - if len(ac.hooks) == 0 { - if err = ac.check(); err != nil { - return nil, err - } - node, err = ac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ac.check(); err != nil { - return nil, err - } - ac.mutation = mutation - if node, err = ac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ac.hooks) - 1; i >= 0; i-- { - if ac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -515,6 +487,12 @@ func (ac *AlertCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (ac *AlertCreate) check() error { + if _, ok := ac.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Alert.created_at"`)} + } + if _, ok := ac.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Alert.updated_at"`)} + } if _, ok := ac.mutation.Scenario(); !ok { return &ValidationError{Name: "scenario", err: errors.New(`ent: missing required field "Alert.scenario"`)} } @@ -525,6 +503,9 @@ func (ac *AlertCreate) check() error { } func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { + if err := ac.check(); err != nil { + return nil, err + } _node, _spec := ac.createSpec() if err := sqlgraph.CreateNode(ctx, ac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -534,204 +515,112 @@ func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ac.mutation.id = &_node.ID + ac.mutation.done = true return _node, nil } func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { var ( _node = &Alert{config: ac.config} - _spec = &sqlgraph.CreateSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) ) if value, ok := ac.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(alert.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := ac.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := ac.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) + _spec.SetField(alert.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := ac.mutation.BucketId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldBucketId, field.TypeString, value) _node.BucketId = value } if value, ok := ac.mutation.Message(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.SetField(alert.FieldMessage, field.TypeString, value) _node.Message = value } if value, ok := ac.mutation.EventsCount(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.SetField(alert.FieldEventsCount, field.TypeInt32, value) _node.EventsCount = value } if value, ok := ac.mutation.StartedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.SetField(alert.FieldStartedAt, field.TypeTime, value) _node.StartedAt = value } if value, ok := ac.mutation.StoppedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.SetField(alert.FieldStoppedAt, field.TypeTime, value) _node.StoppedAt = value } if value, ok := ac.mutation.SourceIp(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.SetField(alert.FieldSourceIp, field.TypeString, value) _node.SourceIp = value } if value, ok := ac.mutation.SourceRange(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.SetField(alert.FieldSourceRange, field.TypeString, value) _node.SourceRange = value } if value, ok := ac.mutation.SourceAsNumber(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value) _node.SourceAsNumber = value } if value, ok := ac.mutation.SourceAsName(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.SetField(alert.FieldSourceAsName, field.TypeString, value) _node.SourceAsName = value } if value, ok := ac.mutation.SourceCountry(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.SetField(alert.FieldSourceCountry, field.TypeString, value) _node.SourceCountry = value } if value, ok := ac.mutation.SourceLatitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value) _node.SourceLatitude = value } if value, ok := ac.mutation.SourceLongitude(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value) _node.SourceLongitude = value } if value, ok := ac.mutation.SourceScope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.SetField(alert.FieldSourceScope, field.TypeString, value) _node.SourceScope = value } if value, ok := ac.mutation.SourceValue(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.SetField(alert.FieldSourceValue, field.TypeString, value) _node.SourceValue = value } if value, ok := ac.mutation.Capacity(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.SetField(alert.FieldCapacity, field.TypeInt32, value) _node.Capacity = value } if value, ok := ac.mutation.LeakSpeed(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.SetField(alert.FieldLeakSpeed, field.TypeString, value) _node.LeakSpeed = value } if value, ok := ac.mutation.ScenarioVersion(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.SetField(alert.FieldScenarioVersion, field.TypeString, value) _node.ScenarioVersion = value } if value, ok := ac.mutation.ScenarioHash(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.SetField(alert.FieldScenarioHash, field.TypeString, value) _node.ScenarioHash = value } if value, ok := ac.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) + _spec.SetField(alert.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := ac.mutation.UUID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.SetField(alert.FieldUUID, field.TypeString, value) _node.UUID = value } + if value, ok := ac.mutation.Remediation(); ok { + _spec.SetField(alert.FieldRemediation, field.TypeBool, value) + _node.Remediation = value + } if nodes := ac.mutation.OwnerIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -740,10 +629,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -760,10 +646,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -779,10 +662,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -798,10 +678,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -815,11 +692,15 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) { // AlertCreateBulk is the builder for creating many Alert entities in bulk. type AlertCreateBulk struct { config + err error builders []*AlertCreate } // Save creates the Alert entities in the database. func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { + if acb.err != nil { + return nil, acb.err + } specs := make([]*sqlgraph.CreateSpec, len(acb.builders)) nodes := make([]*Alert, len(acb.builders)) mutators := make([]Mutator, len(acb.builders)) @@ -836,8 +717,8 @@ func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, acb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/alert_delete.go b/pkg/database/ent/alert_delete.go index 014bcc2e0c6..15b3a4c822a 100644 --- a/pkg/database/ent/alert_delete.go +++ b/pkg/database/ent/alert_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ad *AlertDelete) Where(ps ...predicate.Alert) *AlertDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ad *AlertDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ad.hooks) == 0 { - affected, err = ad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ad.mutation = mutation - affected, err = ad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ad.hooks) - 1; i >= 0; i-- { - if ad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ad.sqlExec, ad.mutation, ad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ad *AlertDelete) ExecX(ctx context.Context) int { } func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := ad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type AlertDeleteOne struct { ad *AlertDelete } +// Where appends a list predicates to the AlertDelete builder. +func (ado *AlertDeleteOne) Where(ps ...predicate.Alert) *AlertDeleteOne { + ado.ad.mutation.Where(ps...) + return ado +} + // Exec executes the deletion query. func (ado *AlertDeleteOne) Exec(ctx context.Context) error { n, err := ado.ad.Exec(ctx) @@ -111,5 +82,7 @@ func (ado *AlertDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ado *AlertDeleteOne) ExecX(ctx context.Context) { - ado.ad.ExecX(ctx) + if err := ado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/alert_query.go b/pkg/database/ent/alert_query.go index 68789196d24..7eddb6ce024 100644 --- a/pkg/database/ent/alert_query.go +++ b/pkg/database/ent/alert_query.go @@ -22,11 +22,9 @@ import ( // AlertQuery is the builder for querying Alert entities. type AlertQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []alert.OrderOption + inters []Interceptor predicates []predicate.Alert withOwner *MachineQuery withDecisions *DecisionQuery @@ -44,34 +42,34 @@ func (aq *AlertQuery) Where(ps ...predicate.Alert) *AlertQuery { return aq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (aq *AlertQuery) Limit(limit int) *AlertQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } -// Offset adds an offset step to the query. +// Offset to start from. func (aq *AlertQuery) Offset(offset int) *AlertQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *AlertQuery) Unique(unique bool) *AlertQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } -// Order adds an order step to the query. -func (aq *AlertQuery) Order(o ...OrderFunc) *AlertQuery { +// Order specifies how the records should be ordered. +func (aq *AlertQuery) Order(o ...alert.OrderOption) *AlertQuery { aq.order = append(aq.order, o...) return aq } // QueryOwner chains the current query on the "owner" edge. func (aq *AlertQuery) QueryOwner() *MachineQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +91,7 @@ func (aq *AlertQuery) QueryOwner() *MachineQuery { // QueryDecisions chains the current query on the "decisions" edge. func (aq *AlertQuery) QueryDecisions() *DecisionQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +113,7 @@ func (aq *AlertQuery) QueryDecisions() *DecisionQuery { // QueryEvents chains the current query on the "events" edge. func (aq *AlertQuery) QueryEvents() *EventQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +135,7 @@ func (aq *AlertQuery) QueryEvents() *EventQuery { // QueryMetas chains the current query on the "metas" edge. func (aq *AlertQuery) QueryMetas() *MetaQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -160,7 +158,7 @@ func (aq *AlertQuery) QueryMetas() *MetaQuery { // First returns the first Alert entity from the query. // Returns a *NotFoundError when no Alert was found. func (aq *AlertQuery) First(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(1).All(ctx) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -183,7 +181,7 @@ func (aq *AlertQuery) FirstX(ctx context.Context) *Alert { // Returns a *NotFoundError when no Alert ID was found. func (aq *AlertQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(1).IDs(ctx); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -206,7 +204,7 @@ func (aq *AlertQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Alert entity is found. // Returns a *NotFoundError when no Alert entities are found. func (aq *AlertQuery) Only(ctx context.Context) (*Alert, error) { - nodes, err := aq.Limit(2).All(ctx) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -234,7 +232,7 @@ func (aq *AlertQuery) OnlyX(ctx context.Context) *Alert { // Returns a *NotFoundError when no entities are found. func (aq *AlertQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(2).IDs(ctx); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -259,10 +257,12 @@ func (aq *AlertQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Alerts. func (aq *AlertQuery) All(ctx context.Context) ([]*Alert, error) { + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } - return aq.sqlAll(ctx) + qr := querierAll[[]*Alert, *AlertQuery]() + return withInterceptors[[]*Alert](ctx, aq, qr, aq.inters) } // AllX is like All, but panics if an error occurs. @@ -275,9 +275,12 @@ func (aq *AlertQuery) AllX(ctx context.Context) []*Alert { } // IDs executes the query and returns a list of Alert IDs. -func (aq *AlertQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { +func (aq *AlertQuery) IDs(ctx context.Context) (ids []int, err error) { + if aq.ctx.Unique == nil && aq.path != nil { + aq.Unique(true) + } + ctx = setContextOp(ctx, aq.ctx, "IDs") + if err = aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -294,10 +297,11 @@ func (aq *AlertQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (aq *AlertQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } - return aq.sqlCount(ctx) + return withInterceptors[int](ctx, aq, querierCount[*AlertQuery](), aq.inters) } // CountX is like Count, but panics if an error occurs. @@ -311,10 +315,15 @@ func (aq *AlertQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *AlertQuery) Exist(ctx context.Context) (bool, error) { - if err := aq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, aq.ctx, "Exist") + switch _, err := aq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return aq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -334,25 +343,24 @@ func (aq *AlertQuery) Clone() *AlertQuery { } return &AlertQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, - order: append([]OrderFunc{}, aq.order...), + ctx: aq.ctx.Clone(), + order: append([]alert.OrderOption{}, aq.order...), + inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Alert{}, aq.predicates...), withOwner: aq.withOwner.Clone(), withDecisions: aq.withDecisions.Clone(), withEvents: aq.withEvents.Clone(), withMetas: aq.withMetas.Clone(), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { - query := &MachineQuery{config: aq.config} + query := (&MachineClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -363,7 +371,7 @@ func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery { // WithDecisions tells the query-builder to eager-load the nodes that are connected to // the "decisions" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { - query := &DecisionQuery{config: aq.config} + query := (&DecisionClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,7 +382,7 @@ func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery { // WithEvents tells the query-builder to eager-load the nodes that are connected to // the "events" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { - query := &EventQuery{config: aq.config} + query := (&EventClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -385,7 +393,7 @@ func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery { // WithMetas tells the query-builder to eager-load the nodes that are connected to // the "metas" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { - query := &MetaQuery{config: aq.config} + query := (&MetaClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -408,16 +416,11 @@ func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { - grbuild := &AlertGroupBy{config: aq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := aq.prepareQuery(ctx); err != nil { - return nil, err - } - return aq.sqlQuery(ctx), nil - } + aq.ctx.Fields = append([]string{field}, fields...) + grbuild := &AlertGroupBy{build: aq} + grbuild.flds = &aq.ctx.Fields grbuild.label = alert.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -434,15 +437,30 @@ func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy { // Select(alert.FieldCreatedAt). // Scan(ctx, &v) func (aq *AlertQuery) Select(fields ...string) *AlertSelect { - aq.fields = append(aq.fields, fields...) - selbuild := &AlertSelect{AlertQuery: aq} - selbuild.label = alert.Label - selbuild.flds, selbuild.scan = &aq.fields, selbuild.Scan - return selbuild + aq.ctx.Fields = append(aq.ctx.Fields, fields...) + sbuild := &AlertSelect{AlertQuery: aq} + sbuild.label = alert.Label + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AlertSelect configured with the given aggregations. +func (aq *AlertQuery) Aggregate(fns ...AggregateFunc) *AlertSelect { + return aq.Select().Aggregate(fns...) } func (aq *AlertQuery) prepareQuery(ctx context.Context) error { - for _, f := range aq.fields { + for _, inter := range aq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, aq); err != nil { + return err + } + } + } + for _, f := range aq.ctx.Fields { if !alert.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -536,6 +554,9 @@ func (aq *AlertQuery) loadOwner(ctx context.Context, query *MachineQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(machine.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -562,8 +583,11 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(decision.FieldAlertDecisions) + } query.Where(predicate.Decision(func(s *sql.Selector) { - s.Where(sql.InValues(alert.DecisionsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.DecisionsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -573,7 +597,7 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n fk := n.AlertDecisions node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -589,8 +613,11 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(event.FieldAlertEvents) + } query.Where(predicate.Event(func(s *sql.Selector) { - s.Where(sql.InValues(alert.EventsColumn, fks...)) + s.Where(sql.InValues(s.C(alert.EventsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -600,7 +627,7 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [ fk := n.AlertEvents node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_events" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_events" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -616,8 +643,11 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* init(nodes[i]) } } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(meta.FieldAlertMetas) + } query.Where(predicate.Meta(func(s *sql.Selector) { - s.Where(sql.InValues(alert.MetasColumn, fks...)) + s.Where(sql.InValues(s.C(alert.MetasColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -627,7 +657,7 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* fk := n.AlertMetas node, ok := nodeids[fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "alert_metas" returned %v for node %v`, fk, n.ID) } assign(node, n) } @@ -636,41 +666,22 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []* func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) { _spec := aq.querySpec() - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) > 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } -func (aq *AlertQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := aq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - From: aq.sql, - Unique: true, - } - if unique := aq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) + _spec.From = aq.sql + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if aq.path != nil { + _spec.Unique = true } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, alert.FieldID) for i := range fields { @@ -686,10 +697,10 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -705,7 +716,7 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec { func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(alert.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = alert.Columns } @@ -714,7 +725,7 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, p := range aq.predicates { @@ -723,12 +734,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -736,13 +747,8 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector { // AlertGroupBy is the group-by builder for Alert entities. type AlertGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *AlertQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -751,74 +757,77 @@ func (agb *AlertGroupBy) Aggregate(fns ...AggregateFunc) *AlertGroupBy { return agb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (agb *AlertGroupBy) Scan(ctx context.Context, v any) error { - query, err := agb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") + if err := agb.build.prepareQuery(ctx); err != nil { return err } - agb.sql = query - return agb.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertGroupBy](ctx, agb.build, agb, agb.build.inters, v) } -func (agb *AlertGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range agb.fields { - if !alert.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (agb *AlertGroupBy) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(agb.fns)) + for _, fn := range agb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*agb.flds)+len(agb.fns)) + for _, f := range *agb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := agb.sqlQuery() + selector.GroupBy(selector.Columns(*agb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := agb.driver.Query(ctx, query, args, rows); err != nil { + if err := agb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (agb *AlertGroupBy) sqlQuery() *sql.Selector { - selector := agb.sql.Select() - aggregation := make([]string, 0, len(agb.fns)) - for _, fn := range agb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(agb.fields)+len(agb.fns)) - for _, f := range agb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(agb.fields...)...) -} - // AlertSelect is the builder for selecting fields of Alert entities. type AlertSelect struct { *AlertQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (as *AlertSelect) Aggregate(fns ...AggregateFunc) *AlertSelect { + as.fns = append(as.fns, fns...) + return as } // Scan applies the selector query and scans the result into the given value. func (as *AlertSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } - as.sql = as.AlertQuery.sqlQuery(ctx) - return as.sqlScan(ctx, v) + return scanWithInterceptors[*AlertQuery, *AlertSelect](ctx, as.AlertQuery, as, as.inters, v) } -func (as *AlertSelect) sqlScan(ctx context.Context, v any) error { +func (as *AlertSelect) sqlScan(ctx context.Context, root *AlertQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(as.fns)) + for _, fn := range as.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*as.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := as.sql.Query() + query, args := selector.Query() if err := as.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/alert_update.go b/pkg/database/ent/alert_update.go index aaa12ef20a3..5f0e01ac09f 100644 --- a/pkg/database/ent/alert_update.go +++ b/pkg/database/ent/alert_update.go @@ -32,458 +32,12 @@ func (au *AlertUpdate) Where(ps ...predicate.Alert) *AlertUpdate { return au } -// SetCreatedAt sets the "created_at" field. -func (au *AlertUpdate) SetCreatedAt(t time.Time) *AlertUpdate { - au.mutation.SetCreatedAt(t) - return au -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (au *AlertUpdate) ClearCreatedAt() *AlertUpdate { - au.mutation.ClearCreatedAt() - return au -} - // SetUpdatedAt sets the "updated_at" field. func (au *AlertUpdate) SetUpdatedAt(t time.Time) *AlertUpdate { au.mutation.SetUpdatedAt(t) return au } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (au *AlertUpdate) ClearUpdatedAt() *AlertUpdate { - au.mutation.ClearUpdatedAt() - return au -} - -// SetScenario sets the "scenario" field. -func (au *AlertUpdate) SetScenario(s string) *AlertUpdate { - au.mutation.SetScenario(s) - return au -} - -// SetBucketId sets the "bucketId" field. -func (au *AlertUpdate) SetBucketId(s string) *AlertUpdate { - au.mutation.SetBucketId(s) - return au -} - -// SetNillableBucketId sets the "bucketId" field if the given value is not nil. -func (au *AlertUpdate) SetNillableBucketId(s *string) *AlertUpdate { - if s != nil { - au.SetBucketId(*s) - } - return au -} - -// ClearBucketId clears the value of the "bucketId" field. -func (au *AlertUpdate) ClearBucketId() *AlertUpdate { - au.mutation.ClearBucketId() - return au -} - -// SetMessage sets the "message" field. -func (au *AlertUpdate) SetMessage(s string) *AlertUpdate { - au.mutation.SetMessage(s) - return au -} - -// SetNillableMessage sets the "message" field if the given value is not nil. -func (au *AlertUpdate) SetNillableMessage(s *string) *AlertUpdate { - if s != nil { - au.SetMessage(*s) - } - return au -} - -// ClearMessage clears the value of the "message" field. -func (au *AlertUpdate) ClearMessage() *AlertUpdate { - au.mutation.ClearMessage() - return au -} - -// SetEventsCount sets the "eventsCount" field. -func (au *AlertUpdate) SetEventsCount(i int32) *AlertUpdate { - au.mutation.ResetEventsCount() - au.mutation.SetEventsCount(i) - return au -} - -// SetNillableEventsCount sets the "eventsCount" field if the given value is not nil. -func (au *AlertUpdate) SetNillableEventsCount(i *int32) *AlertUpdate { - if i != nil { - au.SetEventsCount(*i) - } - return au -} - -// AddEventsCount adds i to the "eventsCount" field. -func (au *AlertUpdate) AddEventsCount(i int32) *AlertUpdate { - au.mutation.AddEventsCount(i) - return au -} - -// ClearEventsCount clears the value of the "eventsCount" field. -func (au *AlertUpdate) ClearEventsCount() *AlertUpdate { - au.mutation.ClearEventsCount() - return au -} - -// SetStartedAt sets the "startedAt" field. -func (au *AlertUpdate) SetStartedAt(t time.Time) *AlertUpdate { - au.mutation.SetStartedAt(t) - return au -} - -// SetNillableStartedAt sets the "startedAt" field if the given value is not nil. -func (au *AlertUpdate) SetNillableStartedAt(t *time.Time) *AlertUpdate { - if t != nil { - au.SetStartedAt(*t) - } - return au -} - -// ClearStartedAt clears the value of the "startedAt" field. -func (au *AlertUpdate) ClearStartedAt() *AlertUpdate { - au.mutation.ClearStartedAt() - return au -} - -// SetStoppedAt sets the "stoppedAt" field. -func (au *AlertUpdate) SetStoppedAt(t time.Time) *AlertUpdate { - au.mutation.SetStoppedAt(t) - return au -} - -// SetNillableStoppedAt sets the "stoppedAt" field if the given value is not nil. -func (au *AlertUpdate) SetNillableStoppedAt(t *time.Time) *AlertUpdate { - if t != nil { - au.SetStoppedAt(*t) - } - return au -} - -// ClearStoppedAt clears the value of the "stoppedAt" field. -func (au *AlertUpdate) ClearStoppedAt() *AlertUpdate { - au.mutation.ClearStoppedAt() - return au -} - -// SetSourceIp sets the "sourceIp" field. -func (au *AlertUpdate) SetSourceIp(s string) *AlertUpdate { - au.mutation.SetSourceIp(s) - return au -} - -// SetNillableSourceIp sets the "sourceIp" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceIp(s *string) *AlertUpdate { - if s != nil { - au.SetSourceIp(*s) - } - return au -} - -// ClearSourceIp clears the value of the "sourceIp" field. -func (au *AlertUpdate) ClearSourceIp() *AlertUpdate { - au.mutation.ClearSourceIp() - return au -} - -// SetSourceRange sets the "sourceRange" field. -func (au *AlertUpdate) SetSourceRange(s string) *AlertUpdate { - au.mutation.SetSourceRange(s) - return au -} - -// SetNillableSourceRange sets the "sourceRange" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceRange(s *string) *AlertUpdate { - if s != nil { - au.SetSourceRange(*s) - } - return au -} - -// ClearSourceRange clears the value of the "sourceRange" field. -func (au *AlertUpdate) ClearSourceRange() *AlertUpdate { - au.mutation.ClearSourceRange() - return au -} - -// SetSourceAsNumber sets the "sourceAsNumber" field. -func (au *AlertUpdate) SetSourceAsNumber(s string) *AlertUpdate { - au.mutation.SetSourceAsNumber(s) - return au -} - -// SetNillableSourceAsNumber sets the "sourceAsNumber" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceAsNumber(s *string) *AlertUpdate { - if s != nil { - au.SetSourceAsNumber(*s) - } - return au -} - -// ClearSourceAsNumber clears the value of the "sourceAsNumber" field. -func (au *AlertUpdate) ClearSourceAsNumber() *AlertUpdate { - au.mutation.ClearSourceAsNumber() - return au -} - -// SetSourceAsName sets the "sourceAsName" field. -func (au *AlertUpdate) SetSourceAsName(s string) *AlertUpdate { - au.mutation.SetSourceAsName(s) - return au -} - -// SetNillableSourceAsName sets the "sourceAsName" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceAsName(s *string) *AlertUpdate { - if s != nil { - au.SetSourceAsName(*s) - } - return au -} - -// ClearSourceAsName clears the value of the "sourceAsName" field. -func (au *AlertUpdate) ClearSourceAsName() *AlertUpdate { - au.mutation.ClearSourceAsName() - return au -} - -// SetSourceCountry sets the "sourceCountry" field. -func (au *AlertUpdate) SetSourceCountry(s string) *AlertUpdate { - au.mutation.SetSourceCountry(s) - return au -} - -// SetNillableSourceCountry sets the "sourceCountry" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceCountry(s *string) *AlertUpdate { - if s != nil { - au.SetSourceCountry(*s) - } - return au -} - -// ClearSourceCountry clears the value of the "sourceCountry" field. -func (au *AlertUpdate) ClearSourceCountry() *AlertUpdate { - au.mutation.ClearSourceCountry() - return au -} - -// SetSourceLatitude sets the "sourceLatitude" field. -func (au *AlertUpdate) SetSourceLatitude(f float32) *AlertUpdate { - au.mutation.ResetSourceLatitude() - au.mutation.SetSourceLatitude(f) - return au -} - -// SetNillableSourceLatitude sets the "sourceLatitude" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceLatitude(f *float32) *AlertUpdate { - if f != nil { - au.SetSourceLatitude(*f) - } - return au -} - -// AddSourceLatitude adds f to the "sourceLatitude" field. -func (au *AlertUpdate) AddSourceLatitude(f float32) *AlertUpdate { - au.mutation.AddSourceLatitude(f) - return au -} - -// ClearSourceLatitude clears the value of the "sourceLatitude" field. -func (au *AlertUpdate) ClearSourceLatitude() *AlertUpdate { - au.mutation.ClearSourceLatitude() - return au -} - -// SetSourceLongitude sets the "sourceLongitude" field. -func (au *AlertUpdate) SetSourceLongitude(f float32) *AlertUpdate { - au.mutation.ResetSourceLongitude() - au.mutation.SetSourceLongitude(f) - return au -} - -// SetNillableSourceLongitude sets the "sourceLongitude" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceLongitude(f *float32) *AlertUpdate { - if f != nil { - au.SetSourceLongitude(*f) - } - return au -} - -// AddSourceLongitude adds f to the "sourceLongitude" field. -func (au *AlertUpdate) AddSourceLongitude(f float32) *AlertUpdate { - au.mutation.AddSourceLongitude(f) - return au -} - -// ClearSourceLongitude clears the value of the "sourceLongitude" field. -func (au *AlertUpdate) ClearSourceLongitude() *AlertUpdate { - au.mutation.ClearSourceLongitude() - return au -} - -// SetSourceScope sets the "sourceScope" field. -func (au *AlertUpdate) SetSourceScope(s string) *AlertUpdate { - au.mutation.SetSourceScope(s) - return au -} - -// SetNillableSourceScope sets the "sourceScope" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceScope(s *string) *AlertUpdate { - if s != nil { - au.SetSourceScope(*s) - } - return au -} - -// ClearSourceScope clears the value of the "sourceScope" field. -func (au *AlertUpdate) ClearSourceScope() *AlertUpdate { - au.mutation.ClearSourceScope() - return au -} - -// SetSourceValue sets the "sourceValue" field. -func (au *AlertUpdate) SetSourceValue(s string) *AlertUpdate { - au.mutation.SetSourceValue(s) - return au -} - -// SetNillableSourceValue sets the "sourceValue" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSourceValue(s *string) *AlertUpdate { - if s != nil { - au.SetSourceValue(*s) - } - return au -} - -// ClearSourceValue clears the value of the "sourceValue" field. -func (au *AlertUpdate) ClearSourceValue() *AlertUpdate { - au.mutation.ClearSourceValue() - return au -} - -// SetCapacity sets the "capacity" field. -func (au *AlertUpdate) SetCapacity(i int32) *AlertUpdate { - au.mutation.ResetCapacity() - au.mutation.SetCapacity(i) - return au -} - -// SetNillableCapacity sets the "capacity" field if the given value is not nil. -func (au *AlertUpdate) SetNillableCapacity(i *int32) *AlertUpdate { - if i != nil { - au.SetCapacity(*i) - } - return au -} - -// AddCapacity adds i to the "capacity" field. -func (au *AlertUpdate) AddCapacity(i int32) *AlertUpdate { - au.mutation.AddCapacity(i) - return au -} - -// ClearCapacity clears the value of the "capacity" field. -func (au *AlertUpdate) ClearCapacity() *AlertUpdate { - au.mutation.ClearCapacity() - return au -} - -// SetLeakSpeed sets the "leakSpeed" field. -func (au *AlertUpdate) SetLeakSpeed(s string) *AlertUpdate { - au.mutation.SetLeakSpeed(s) - return au -} - -// SetNillableLeakSpeed sets the "leakSpeed" field if the given value is not nil. -func (au *AlertUpdate) SetNillableLeakSpeed(s *string) *AlertUpdate { - if s != nil { - au.SetLeakSpeed(*s) - } - return au -} - -// ClearLeakSpeed clears the value of the "leakSpeed" field. -func (au *AlertUpdate) ClearLeakSpeed() *AlertUpdate { - au.mutation.ClearLeakSpeed() - return au -} - -// SetScenarioVersion sets the "scenarioVersion" field. -func (au *AlertUpdate) SetScenarioVersion(s string) *AlertUpdate { - au.mutation.SetScenarioVersion(s) - return au -} - -// SetNillableScenarioVersion sets the "scenarioVersion" field if the given value is not nil. -func (au *AlertUpdate) SetNillableScenarioVersion(s *string) *AlertUpdate { - if s != nil { - au.SetScenarioVersion(*s) - } - return au -} - -// ClearScenarioVersion clears the value of the "scenarioVersion" field. -func (au *AlertUpdate) ClearScenarioVersion() *AlertUpdate { - au.mutation.ClearScenarioVersion() - return au -} - -// SetScenarioHash sets the "scenarioHash" field. -func (au *AlertUpdate) SetScenarioHash(s string) *AlertUpdate { - au.mutation.SetScenarioHash(s) - return au -} - -// SetNillableScenarioHash sets the "scenarioHash" field if the given value is not nil. -func (au *AlertUpdate) SetNillableScenarioHash(s *string) *AlertUpdate { - if s != nil { - au.SetScenarioHash(*s) - } - return au -} - -// ClearScenarioHash clears the value of the "scenarioHash" field. -func (au *AlertUpdate) ClearScenarioHash() *AlertUpdate { - au.mutation.ClearScenarioHash() - return au -} - -// SetSimulated sets the "simulated" field. -func (au *AlertUpdate) SetSimulated(b bool) *AlertUpdate { - au.mutation.SetSimulated(b) - return au -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (au *AlertUpdate) SetNillableSimulated(b *bool) *AlertUpdate { - if b != nil { - au.SetSimulated(*b) - } - return au -} - -// SetUUID sets the "uuid" field. -func (au *AlertUpdate) SetUUID(s string) *AlertUpdate { - au.mutation.SetUUID(s) - return au -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (au *AlertUpdate) SetNillableUUID(s *string) *AlertUpdate { - if s != nil { - au.SetUUID(*s) - } - return au -} - -// ClearUUID clears the value of the "uuid" field. -func (au *AlertUpdate) ClearUUID() *AlertUpdate { - au.mutation.ClearUUID() - return au -} - // SetOwnerID sets the "owner" edge to the Machine entity by ID. func (au *AlertUpdate) SetOwnerID(id int) *AlertUpdate { au.mutation.SetOwnerID(id) @@ -624,35 +178,8 @@ func (au *AlertUpdate) RemoveMetas(m ...*Meta) *AlertUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (au *AlertUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) au.defaults() - if len(au.hooks) == 0 { - affected, err = au.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - au.mutation = mutation - affected, err = au.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(au.hooks) - 1; i >= 0; i-- { - if au.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = au.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, au.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, au.sqlSave, au.mutation, au.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -679,27 +206,14 @@ func (au *AlertUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (au *AlertUpdate) defaults() { - if _, ok := au.mutation.CreatedAt(); !ok && !au.mutation.CreatedAtCleared() { - v := alert.UpdateDefaultCreatedAt() - au.mutation.SetCreatedAt(v) - } - if _, ok := au.mutation.UpdatedAt(); !ok && !au.mutation.UpdatedAtCleared() { + if _, ok := au.mutation.UpdatedAt(); !ok { v := alert.UpdateDefaultUpdatedAt() au.mutation.SetUpdatedAt(v) } } func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) if ps := au.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -707,320 +221,68 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := au.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - } - if au.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) - } if value, ok := au.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - } - if au.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) - } - if value, ok := au.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) - } - if value, ok := au.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if au.mutation.BucketIdCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldBucketId, - }) - } - if value, ok := au.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if au.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) - } - if value, ok := au.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) - } - if value, ok := au.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if au.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldEventsCount, - }) - } - if value, ok := au.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if au.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) - } - if value, ok := au.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if au.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) - } - if value, ok := au.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if au.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) - } - if value, ok := au.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if au.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) - } - if value, ok := au.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if au.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) - } - if value, ok := au.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if au.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) - } - if value, ok := au.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if au.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) - } - if value, ok := au.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := au.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if au.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := au.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := au.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if au.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := au.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if au.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) - } - if value, ok := au.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceScope, field.TypeString) } if au.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) - } - if value, ok := au.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) - } - if value, ok := au.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if au.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) - } - if value, ok := au.mutation.LeakSpeed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if au.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) - } - if value, ok := au.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if au.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) - } - if value, ok := au.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if au.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) - } - if value, ok := au.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) - } - if value, ok := au.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if au.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) + } + if au.mutation.RemediationCleared() { + _spec.ClearField(alert.FieldRemediation, field.TypeBool) } if au.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1030,10 +292,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1046,10 +305,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1065,10 +321,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1081,10 +334,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1100,10 +350,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1119,10 +366,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1135,10 +379,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1154,10 +395,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1173,10 +411,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1189,10 +424,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1208,10 +440,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1227,6 +456,7 @@ func (au *AlertUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + au.mutation.done = true return n, nil } @@ -1238,458 +468,12 @@ type AlertUpdateOne struct { mutation *AlertMutation } -// SetCreatedAt sets the "created_at" field. -func (auo *AlertUpdateOne) SetCreatedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetCreatedAt(t) - return auo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (auo *AlertUpdateOne) ClearCreatedAt() *AlertUpdateOne { - auo.mutation.ClearCreatedAt() - return auo -} - // SetUpdatedAt sets the "updated_at" field. func (auo *AlertUpdateOne) SetUpdatedAt(t time.Time) *AlertUpdateOne { auo.mutation.SetUpdatedAt(t) return auo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (auo *AlertUpdateOne) ClearUpdatedAt() *AlertUpdateOne { - auo.mutation.ClearUpdatedAt() - return auo -} - -// SetScenario sets the "scenario" field. -func (auo *AlertUpdateOne) SetScenario(s string) *AlertUpdateOne { - auo.mutation.SetScenario(s) - return auo -} - -// SetBucketId sets the "bucketId" field. -func (auo *AlertUpdateOne) SetBucketId(s string) *AlertUpdateOne { - auo.mutation.SetBucketId(s) - return auo -} - -// SetNillableBucketId sets the "bucketId" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableBucketId(s *string) *AlertUpdateOne { - if s != nil { - auo.SetBucketId(*s) - } - return auo -} - -// ClearBucketId clears the value of the "bucketId" field. -func (auo *AlertUpdateOne) ClearBucketId() *AlertUpdateOne { - auo.mutation.ClearBucketId() - return auo -} - -// SetMessage sets the "message" field. -func (auo *AlertUpdateOne) SetMessage(s string) *AlertUpdateOne { - auo.mutation.SetMessage(s) - return auo -} - -// SetNillableMessage sets the "message" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableMessage(s *string) *AlertUpdateOne { - if s != nil { - auo.SetMessage(*s) - } - return auo -} - -// ClearMessage clears the value of the "message" field. -func (auo *AlertUpdateOne) ClearMessage() *AlertUpdateOne { - auo.mutation.ClearMessage() - return auo -} - -// SetEventsCount sets the "eventsCount" field. -func (auo *AlertUpdateOne) SetEventsCount(i int32) *AlertUpdateOne { - auo.mutation.ResetEventsCount() - auo.mutation.SetEventsCount(i) - return auo -} - -// SetNillableEventsCount sets the "eventsCount" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableEventsCount(i *int32) *AlertUpdateOne { - if i != nil { - auo.SetEventsCount(*i) - } - return auo -} - -// AddEventsCount adds i to the "eventsCount" field. -func (auo *AlertUpdateOne) AddEventsCount(i int32) *AlertUpdateOne { - auo.mutation.AddEventsCount(i) - return auo -} - -// ClearEventsCount clears the value of the "eventsCount" field. -func (auo *AlertUpdateOne) ClearEventsCount() *AlertUpdateOne { - auo.mutation.ClearEventsCount() - return auo -} - -// SetStartedAt sets the "startedAt" field. -func (auo *AlertUpdateOne) SetStartedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetStartedAt(t) - return auo -} - -// SetNillableStartedAt sets the "startedAt" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableStartedAt(t *time.Time) *AlertUpdateOne { - if t != nil { - auo.SetStartedAt(*t) - } - return auo -} - -// ClearStartedAt clears the value of the "startedAt" field. -func (auo *AlertUpdateOne) ClearStartedAt() *AlertUpdateOne { - auo.mutation.ClearStartedAt() - return auo -} - -// SetStoppedAt sets the "stoppedAt" field. -func (auo *AlertUpdateOne) SetStoppedAt(t time.Time) *AlertUpdateOne { - auo.mutation.SetStoppedAt(t) - return auo -} - -// SetNillableStoppedAt sets the "stoppedAt" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableStoppedAt(t *time.Time) *AlertUpdateOne { - if t != nil { - auo.SetStoppedAt(*t) - } - return auo -} - -// ClearStoppedAt clears the value of the "stoppedAt" field. -func (auo *AlertUpdateOne) ClearStoppedAt() *AlertUpdateOne { - auo.mutation.ClearStoppedAt() - return auo -} - -// SetSourceIp sets the "sourceIp" field. -func (auo *AlertUpdateOne) SetSourceIp(s string) *AlertUpdateOne { - auo.mutation.SetSourceIp(s) - return auo -} - -// SetNillableSourceIp sets the "sourceIp" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceIp(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceIp(*s) - } - return auo -} - -// ClearSourceIp clears the value of the "sourceIp" field. -func (auo *AlertUpdateOne) ClearSourceIp() *AlertUpdateOne { - auo.mutation.ClearSourceIp() - return auo -} - -// SetSourceRange sets the "sourceRange" field. -func (auo *AlertUpdateOne) SetSourceRange(s string) *AlertUpdateOne { - auo.mutation.SetSourceRange(s) - return auo -} - -// SetNillableSourceRange sets the "sourceRange" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceRange(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceRange(*s) - } - return auo -} - -// ClearSourceRange clears the value of the "sourceRange" field. -func (auo *AlertUpdateOne) ClearSourceRange() *AlertUpdateOne { - auo.mutation.ClearSourceRange() - return auo -} - -// SetSourceAsNumber sets the "sourceAsNumber" field. -func (auo *AlertUpdateOne) SetSourceAsNumber(s string) *AlertUpdateOne { - auo.mutation.SetSourceAsNumber(s) - return auo -} - -// SetNillableSourceAsNumber sets the "sourceAsNumber" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceAsNumber(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceAsNumber(*s) - } - return auo -} - -// ClearSourceAsNumber clears the value of the "sourceAsNumber" field. -func (auo *AlertUpdateOne) ClearSourceAsNumber() *AlertUpdateOne { - auo.mutation.ClearSourceAsNumber() - return auo -} - -// SetSourceAsName sets the "sourceAsName" field. -func (auo *AlertUpdateOne) SetSourceAsName(s string) *AlertUpdateOne { - auo.mutation.SetSourceAsName(s) - return auo -} - -// SetNillableSourceAsName sets the "sourceAsName" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceAsName(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceAsName(*s) - } - return auo -} - -// ClearSourceAsName clears the value of the "sourceAsName" field. -func (auo *AlertUpdateOne) ClearSourceAsName() *AlertUpdateOne { - auo.mutation.ClearSourceAsName() - return auo -} - -// SetSourceCountry sets the "sourceCountry" field. -func (auo *AlertUpdateOne) SetSourceCountry(s string) *AlertUpdateOne { - auo.mutation.SetSourceCountry(s) - return auo -} - -// SetNillableSourceCountry sets the "sourceCountry" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceCountry(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceCountry(*s) - } - return auo -} - -// ClearSourceCountry clears the value of the "sourceCountry" field. -func (auo *AlertUpdateOne) ClearSourceCountry() *AlertUpdateOne { - auo.mutation.ClearSourceCountry() - return auo -} - -// SetSourceLatitude sets the "sourceLatitude" field. -func (auo *AlertUpdateOne) SetSourceLatitude(f float32) *AlertUpdateOne { - auo.mutation.ResetSourceLatitude() - auo.mutation.SetSourceLatitude(f) - return auo -} - -// SetNillableSourceLatitude sets the "sourceLatitude" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceLatitude(f *float32) *AlertUpdateOne { - if f != nil { - auo.SetSourceLatitude(*f) - } - return auo -} - -// AddSourceLatitude adds f to the "sourceLatitude" field. -func (auo *AlertUpdateOne) AddSourceLatitude(f float32) *AlertUpdateOne { - auo.mutation.AddSourceLatitude(f) - return auo -} - -// ClearSourceLatitude clears the value of the "sourceLatitude" field. -func (auo *AlertUpdateOne) ClearSourceLatitude() *AlertUpdateOne { - auo.mutation.ClearSourceLatitude() - return auo -} - -// SetSourceLongitude sets the "sourceLongitude" field. -func (auo *AlertUpdateOne) SetSourceLongitude(f float32) *AlertUpdateOne { - auo.mutation.ResetSourceLongitude() - auo.mutation.SetSourceLongitude(f) - return auo -} - -// SetNillableSourceLongitude sets the "sourceLongitude" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceLongitude(f *float32) *AlertUpdateOne { - if f != nil { - auo.SetSourceLongitude(*f) - } - return auo -} - -// AddSourceLongitude adds f to the "sourceLongitude" field. -func (auo *AlertUpdateOne) AddSourceLongitude(f float32) *AlertUpdateOne { - auo.mutation.AddSourceLongitude(f) - return auo -} - -// ClearSourceLongitude clears the value of the "sourceLongitude" field. -func (auo *AlertUpdateOne) ClearSourceLongitude() *AlertUpdateOne { - auo.mutation.ClearSourceLongitude() - return auo -} - -// SetSourceScope sets the "sourceScope" field. -func (auo *AlertUpdateOne) SetSourceScope(s string) *AlertUpdateOne { - auo.mutation.SetSourceScope(s) - return auo -} - -// SetNillableSourceScope sets the "sourceScope" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceScope(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceScope(*s) - } - return auo -} - -// ClearSourceScope clears the value of the "sourceScope" field. -func (auo *AlertUpdateOne) ClearSourceScope() *AlertUpdateOne { - auo.mutation.ClearSourceScope() - return auo -} - -// SetSourceValue sets the "sourceValue" field. -func (auo *AlertUpdateOne) SetSourceValue(s string) *AlertUpdateOne { - auo.mutation.SetSourceValue(s) - return auo -} - -// SetNillableSourceValue sets the "sourceValue" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSourceValue(s *string) *AlertUpdateOne { - if s != nil { - auo.SetSourceValue(*s) - } - return auo -} - -// ClearSourceValue clears the value of the "sourceValue" field. -func (auo *AlertUpdateOne) ClearSourceValue() *AlertUpdateOne { - auo.mutation.ClearSourceValue() - return auo -} - -// SetCapacity sets the "capacity" field. -func (auo *AlertUpdateOne) SetCapacity(i int32) *AlertUpdateOne { - auo.mutation.ResetCapacity() - auo.mutation.SetCapacity(i) - return auo -} - -// SetNillableCapacity sets the "capacity" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableCapacity(i *int32) *AlertUpdateOne { - if i != nil { - auo.SetCapacity(*i) - } - return auo -} - -// AddCapacity adds i to the "capacity" field. -func (auo *AlertUpdateOne) AddCapacity(i int32) *AlertUpdateOne { - auo.mutation.AddCapacity(i) - return auo -} - -// ClearCapacity clears the value of the "capacity" field. -func (auo *AlertUpdateOne) ClearCapacity() *AlertUpdateOne { - auo.mutation.ClearCapacity() - return auo -} - -// SetLeakSpeed sets the "leakSpeed" field. -func (auo *AlertUpdateOne) SetLeakSpeed(s string) *AlertUpdateOne { - auo.mutation.SetLeakSpeed(s) - return auo -} - -// SetNillableLeakSpeed sets the "leakSpeed" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableLeakSpeed(s *string) *AlertUpdateOne { - if s != nil { - auo.SetLeakSpeed(*s) - } - return auo -} - -// ClearLeakSpeed clears the value of the "leakSpeed" field. -func (auo *AlertUpdateOne) ClearLeakSpeed() *AlertUpdateOne { - auo.mutation.ClearLeakSpeed() - return auo -} - -// SetScenarioVersion sets the "scenarioVersion" field. -func (auo *AlertUpdateOne) SetScenarioVersion(s string) *AlertUpdateOne { - auo.mutation.SetScenarioVersion(s) - return auo -} - -// SetNillableScenarioVersion sets the "scenarioVersion" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableScenarioVersion(s *string) *AlertUpdateOne { - if s != nil { - auo.SetScenarioVersion(*s) - } - return auo -} - -// ClearScenarioVersion clears the value of the "scenarioVersion" field. -func (auo *AlertUpdateOne) ClearScenarioVersion() *AlertUpdateOne { - auo.mutation.ClearScenarioVersion() - return auo -} - -// SetScenarioHash sets the "scenarioHash" field. -func (auo *AlertUpdateOne) SetScenarioHash(s string) *AlertUpdateOne { - auo.mutation.SetScenarioHash(s) - return auo -} - -// SetNillableScenarioHash sets the "scenarioHash" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableScenarioHash(s *string) *AlertUpdateOne { - if s != nil { - auo.SetScenarioHash(*s) - } - return auo -} - -// ClearScenarioHash clears the value of the "scenarioHash" field. -func (auo *AlertUpdateOne) ClearScenarioHash() *AlertUpdateOne { - auo.mutation.ClearScenarioHash() - return auo -} - -// SetSimulated sets the "simulated" field. -func (auo *AlertUpdateOne) SetSimulated(b bool) *AlertUpdateOne { - auo.mutation.SetSimulated(b) - return auo -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableSimulated(b *bool) *AlertUpdateOne { - if b != nil { - auo.SetSimulated(*b) - } - return auo -} - -// SetUUID sets the "uuid" field. -func (auo *AlertUpdateOne) SetUUID(s string) *AlertUpdateOne { - auo.mutation.SetUUID(s) - return auo -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (auo *AlertUpdateOne) SetNillableUUID(s *string) *AlertUpdateOne { - if s != nil { - auo.SetUUID(*s) - } - return auo -} - -// ClearUUID clears the value of the "uuid" field. -func (auo *AlertUpdateOne) ClearUUID() *AlertUpdateOne { - auo.mutation.ClearUUID() - return auo -} - // SetOwnerID sets the "owner" edge to the Machine entity by ID. func (auo *AlertUpdateOne) SetOwnerID(id int) *AlertUpdateOne { auo.mutation.SetOwnerID(id) @@ -1828,6 +612,12 @@ func (auo *AlertUpdateOne) RemoveMetas(m ...*Meta) *AlertUpdateOne { return auo.RemoveMetaIDs(ids...) } +// Where appends a list predicates to the AlertUpdate builder. +func (auo *AlertUpdateOne) Where(ps ...predicate.Alert) *AlertUpdateOne { + auo.mutation.Where(ps...) + return auo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOne { @@ -1837,41 +627,8 @@ func (auo *AlertUpdateOne) Select(field string, fields ...string) *AlertUpdateOn // Save executes the query and returns the updated Alert entity. func (auo *AlertUpdateOne) Save(ctx context.Context) (*Alert, error) { - var ( - err error - node *Alert - ) auo.defaults() - if len(auo.hooks) == 0 { - node, err = auo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - auo.mutation = mutation - node, err = auo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(auo.hooks) - 1; i >= 0; i-- { - if auo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = auo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, auo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Alert) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, auo.sqlSave, auo.mutation, auo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -1898,27 +655,14 @@ func (auo *AlertUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (auo *AlertUpdateOne) defaults() { - if _, ok := auo.mutation.CreatedAt(); !ok && !auo.mutation.CreatedAtCleared() { - v := alert.UpdateDefaultCreatedAt() - auo.mutation.SetCreatedAt(v) - } - if _, ok := auo.mutation.UpdatedAt(); !ok && !auo.mutation.UpdatedAtCleared() { + if _, ok := auo.mutation.UpdatedAt(); !ok { v := alert.UpdateDefaultUpdatedAt() auo.mutation.SetUpdatedAt(v) } } func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: alert.Table, - Columns: alert.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt)) id, ok := auo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Alert.id" for update`)} @@ -1943,320 +687,68 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } } } - if value, ok := auo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldCreatedAt, - }) - } - if auo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldCreatedAt, - }) - } if value, ok := auo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldUpdatedAt, - }) - } - if auo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldUpdatedAt, - }) - } - if value, ok := auo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenario, - }) - } - if value, ok := auo.mutation.BucketId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldBucketId, - }) + _spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value) } if auo.mutation.BucketIdCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldBucketId, - }) - } - if value, ok := auo.mutation.Message(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldMessage, - }) + _spec.ClearField(alert.FieldBucketId, field.TypeString) } if auo.mutation.MessageCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldMessage, - }) - } - if value, ok := auo.mutation.EventsCount(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) - } - if value, ok := auo.mutation.AddedEventsCount(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldEventsCount, - }) + _spec.ClearField(alert.FieldMessage, field.TypeString) } if auo.mutation.EventsCountCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldEventsCount, - }) - } - if value, ok := auo.mutation.StartedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStartedAt, - }) + _spec.ClearField(alert.FieldEventsCount, field.TypeInt32) } if auo.mutation.StartedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStartedAt, - }) - } - if value, ok := auo.mutation.StoppedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: alert.FieldStoppedAt, - }) + _spec.ClearField(alert.FieldStartedAt, field.TypeTime) } if auo.mutation.StoppedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: alert.FieldStoppedAt, - }) - } - if value, ok := auo.mutation.SourceIp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceIp, - }) + _spec.ClearField(alert.FieldStoppedAt, field.TypeTime) } if auo.mutation.SourceIpCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceIp, - }) - } - if value, ok := auo.mutation.SourceRange(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceRange, - }) + _spec.ClearField(alert.FieldSourceIp, field.TypeString) } if auo.mutation.SourceRangeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceRange, - }) - } - if value, ok := auo.mutation.SourceAsNumber(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsNumber, - }) + _spec.ClearField(alert.FieldSourceRange, field.TypeString) } if auo.mutation.SourceAsNumberCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsNumber, - }) - } - if value, ok := auo.mutation.SourceAsName(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceAsName, - }) + _spec.ClearField(alert.FieldSourceAsNumber, field.TypeString) } if auo.mutation.SourceAsNameCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceAsName, - }) - } - if value, ok := auo.mutation.SourceCountry(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceCountry, - }) + _spec.ClearField(alert.FieldSourceAsName, field.TypeString) } if auo.mutation.SourceCountryCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceCountry, - }) - } - if value, ok := auo.mutation.SourceLatitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := auo.mutation.AddedSourceLatitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLatitude, - }) + _spec.ClearField(alert.FieldSourceCountry, field.TypeString) } if auo.mutation.SourceLatitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLatitude, - }) - } - if value, ok := auo.mutation.SourceLongitude(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := auo.mutation.AddedSourceLongitude(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Value: value, - Column: alert.FieldSourceLongitude, - }) + _spec.ClearField(alert.FieldSourceLatitude, field.TypeFloat32) } if auo.mutation.SourceLongitudeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeFloat32, - Column: alert.FieldSourceLongitude, - }) - } - if value, ok := auo.mutation.SourceScope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceScope, - }) + _spec.ClearField(alert.FieldSourceLongitude, field.TypeFloat32) } if auo.mutation.SourceScopeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceScope, - }) - } - if value, ok := auo.mutation.SourceValue(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldSourceValue, - }) + _spec.ClearField(alert.FieldSourceScope, field.TypeString) } if auo.mutation.SourceValueCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldSourceValue, - }) - } - if value, ok := auo.mutation.Capacity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) - } - if value, ok := auo.mutation.AddedCapacity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Value: value, - Column: alert.FieldCapacity, - }) + _spec.ClearField(alert.FieldSourceValue, field.TypeString) } if auo.mutation.CapacityCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt32, - Column: alert.FieldCapacity, - }) - } - if value, ok := auo.mutation.LeakSpeed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldLeakSpeed, - }) + _spec.ClearField(alert.FieldCapacity, field.TypeInt32) } if auo.mutation.LeakSpeedCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldLeakSpeed, - }) - } - if value, ok := auo.mutation.ScenarioVersion(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioVersion, - }) + _spec.ClearField(alert.FieldLeakSpeed, field.TypeString) } if auo.mutation.ScenarioVersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioVersion, - }) - } - if value, ok := auo.mutation.ScenarioHash(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldScenarioHash, - }) + _spec.ClearField(alert.FieldScenarioVersion, field.TypeString) } if auo.mutation.ScenarioHashCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldScenarioHash, - }) - } - if value, ok := auo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: alert.FieldSimulated, - }) - } - if value, ok := auo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldScenarioHash, field.TypeString) } if auo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: alert.FieldUUID, - }) + _spec.ClearField(alert.FieldUUID, field.TypeString) + } + if auo.mutation.RemediationCleared() { + _spec.ClearField(alert.FieldRemediation, field.TypeBool) } if auo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -2266,10 +758,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2282,10 +771,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2301,10 +787,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2317,10 +800,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2336,10 +816,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.DecisionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2355,10 +832,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2371,10 +845,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2390,10 +861,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.EventsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2409,10 +877,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -2425,10 +890,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2444,10 +906,7 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error Columns: []string{alert.MetasColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -2466,5 +925,6 @@ func (auo *AlertUpdateOne) sqlSave(ctx context.Context) (_node *Alert, err error } return nil, err } + auo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 068fc6c6713..3b4d619e384 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" ) @@ -17,13 +18,13 @@ type Bouncer struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at"` + CreatedAt time.Time `json:"created_at"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at"` + UpdatedAt time.Time `json:"updated_at"` // Name holds the value of the "name" field. Name string `json:"name"` // APIKey holds the value of the "api_key" field. - APIKey string `json:"api_key"` + APIKey string `json:"-"` // Revoked holds the value of the "revoked" field. Revoked bool `json:"revoked"` // IPAddress holds the value of the "ip_address" field. @@ -32,12 +33,17 @@ type Bouncer struct { Type string `json:"type"` // Version holds the value of the "version" field. Version string `json:"version"` - // Until holds the value of the "until" field. - Until time.Time `json:"until"` // LastPull holds the value of the "last_pull" field. - LastPull time.Time `json:"last_pull"` + LastPull *time.Time `json:"last_pull"` // AuthType holds the value of the "auth_type" field. AuthType string `json:"auth_type"` + // Osname holds the value of the "osname" field. + Osname string `json:"osname,omitempty"` + // Osversion holds the value of the "osversion" field. + Osversion string `json:"osversion,omitempty"` + // Featureflags holds the value of the "featureflags" field. + Featureflags string `json:"featureflags,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -49,12 +55,12 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) - case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType: + case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType, bouncer.FieldOsname, bouncer.FieldOsversion, bouncer.FieldFeatureflags: values[i] = new(sql.NullString) - case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull: + case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldLastPull: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Bouncer", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -78,15 +84,13 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - b.CreatedAt = new(time.Time) - *b.CreatedAt = value.Time + b.CreatedAt = value.Time } case bouncer.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - b.UpdatedAt = new(time.Time) - *b.UpdatedAt = value.Time + b.UpdatedAt = value.Time } case bouncer.FieldName: if value, ok := values[i].(*sql.NullString); !ok { @@ -124,17 +128,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.Version = value.String } - case bouncer.FieldUntil: - if value, ok := values[i].(*sql.NullTime); !ok { - return fmt.Errorf("unexpected type %T for field until", values[i]) - } else if value.Valid { - b.Until = value.Time - } case bouncer.FieldLastPull: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field last_pull", values[i]) } else if value.Valid { - b.LastPull = value.Time + b.LastPull = new(time.Time) + *b.LastPull = value.Time } case bouncer.FieldAuthType: if value, ok := values[i].(*sql.NullString); !ok { @@ -142,16 +141,42 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.AuthType = value.String } + case bouncer.FieldOsname: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osname", values[i]) + } else if value.Valid { + b.Osname = value.String + } + case bouncer.FieldOsversion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osversion", values[i]) + } else if value.Valid { + b.Osversion = value.String + } + case bouncer.FieldFeatureflags: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field featureflags", values[i]) + } else if value.Valid { + b.Featureflags = value.String + } + default: + b.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Bouncer. +// This includes values selected through modifiers, order, etc. +func (b *Bouncer) Value(name string) (ent.Value, error) { + return b.selectValues.Get(name) +} + // Update returns a builder for updating this Bouncer. // Note that you need to call Bouncer.Unwrap() before calling this method if this Bouncer // was returned from a transaction, and the transaction was committed or rolled back. func (b *Bouncer) Update() *BouncerUpdateOne { - return (&BouncerClient{config: b.config}).UpdateOne(b) + return NewBouncerClient(b.config).UpdateOne(b) } // Unwrap unwraps the Bouncer entity that was returned from a transaction after it was closed, @@ -170,21 +195,16 @@ func (b *Bouncer) String() string { var builder strings.Builder builder.WriteString("Bouncer(") builder.WriteString(fmt.Sprintf("id=%v, ", b.ID)) - if v := b.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(b.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := b.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(b.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(b.Name) builder.WriteString(", ") - builder.WriteString("api_key=") - builder.WriteString(b.APIKey) + builder.WriteString("api_key=") builder.WriteString(", ") builder.WriteString("revoked=") builder.WriteString(fmt.Sprintf("%v", b.Revoked)) @@ -198,23 +218,25 @@ func (b *Bouncer) String() string { builder.WriteString("version=") builder.WriteString(b.Version) builder.WriteString(", ") - builder.WriteString("until=") - builder.WriteString(b.Until.Format(time.ANSIC)) - builder.WriteString(", ") - builder.WriteString("last_pull=") - builder.WriteString(b.LastPull.Format(time.ANSIC)) + if v := b.LastPull; v != nil { + builder.WriteString("last_pull=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteString(", ") builder.WriteString("auth_type=") builder.WriteString(b.AuthType) + builder.WriteString(", ") + builder.WriteString("osname=") + builder.WriteString(b.Osname) + builder.WriteString(", ") + builder.WriteString("osversion=") + builder.WriteString(b.Osversion) + builder.WriteString(", ") + builder.WriteString("featureflags=") + builder.WriteString(b.Featureflags) builder.WriteByte(')') return builder.String() } // Bouncers is a parsable slice of Bouncer. type Bouncers []*Bouncer - -func (b Bouncers) config(cfg config) { - for _i := range b { - b[_i].config = cfg - } -} diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index b688594ece4..a6f62aeadd5 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -4,6 +4,8 @@ package bouncer import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -27,12 +29,16 @@ const ( FieldType = "type" // FieldVersion holds the string denoting the version field in the database. FieldVersion = "version" - // FieldUntil holds the string denoting the until field in the database. - FieldUntil = "until" // FieldLastPull holds the string denoting the last_pull field in the database. FieldLastPull = "last_pull" // FieldAuthType holds the string denoting the auth_type field in the database. FieldAuthType = "auth_type" + // FieldOsname holds the string denoting the osname field in the database. + FieldOsname = "osname" + // FieldOsversion holds the string denoting the osversion field in the database. + FieldOsversion = "osversion" + // FieldFeatureflags holds the string denoting the featureflags field in the database. + FieldFeatureflags = "featureflags" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -48,9 +54,11 @@ var Columns = []string{ FieldIPAddress, FieldType, FieldVersion, - FieldUntil, FieldLastPull, FieldAuthType, + FieldOsname, + FieldOsversion, + FieldFeatureflags, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -66,18 +74,85 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time // DefaultIPAddress holds the default value on creation for the "ip_address" field. DefaultIPAddress string - // DefaultUntil holds the default value on creation for the "until" field. - DefaultUntil func() time.Time - // DefaultLastPull holds the default value on creation for the "last_pull" field. - DefaultLastPull func() time.Time // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Bouncer queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByAPIKey orders the results by the api_key field. +func ByAPIKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKey, opts...).ToFunc() +} + +// ByRevoked orders the results by the revoked field. +func ByRevoked(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRevoked, opts...).ToFunc() +} + +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByLastPull orders the results by the last_pull field. +func ByLastPull(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPull, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} + +// ByOsname orders the results by the osname field. +func ByOsname(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsname, opts...).ToFunc() +} + +// ByOsversion orders the results by the osversion field. +func ByOsversion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsversion, opts...).ToFunc() +} + +// ByFeatureflags orders the results by the featureflags field. +func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index 03a543f6d4f..e02199bc0a9 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -11,1128 +11,910 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // APIKey applies equality check predicate on the "api_key" field. It's identical to APIKeyEQ. func APIKey(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // Revoked applies equality check predicate on the "revoked" field. It's identical to RevokedEQ. func Revoked(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. func IPAddress(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) -} - -// Until applies equality check predicate on the "until" field. It's identical to UntilEQ. -func Until(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldVersion, v)) } // LastPull applies equality check predicate on the "last_pull" field. It's identical to LastPullEQ. func LastPull(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v)) +} + +// Osname applies equality check predicate on the "osname" field. It's identical to OsnameEQ. +func Osname(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsname, v)) +} + +// Osversion applies equality check predicate on the "osversion" field. It's identical to OsversionEQ. +func Osversion(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsversion, v)) +} + +// Featureflags applies equality check predicate on the "featureflags" field. It's identical to FeatureflagsEQ. +func Featureflags(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Bouncer(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Bouncer(sql.FieldLTE(FieldUpdatedAt, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldName, v)) } // APIKeyEQ applies the EQ predicate on the "api_key" field. func APIKeyEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAPIKey, v)) } // APIKeyNEQ applies the NEQ predicate on the "api_key" field. func APIKeyNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldAPIKey, v)) } // APIKeyIn applies the In predicate on the "api_key" field. func APIKeyIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldAPIKey, vs...)) } // APIKeyNotIn applies the NotIn predicate on the "api_key" field. func APIKeyNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAPIKey), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldAPIKey, vs...)) } // APIKeyGT applies the GT predicate on the "api_key" field. func APIKeyGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldAPIKey, v)) } // APIKeyGTE applies the GTE predicate on the "api_key" field. func APIKeyGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldAPIKey, v)) } // APIKeyLT applies the LT predicate on the "api_key" field. func APIKeyLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldAPIKey, v)) } // APIKeyLTE applies the LTE predicate on the "api_key" field. func APIKeyLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldAPIKey, v)) } // APIKeyContains applies the Contains predicate on the "api_key" field. func APIKeyContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldAPIKey, v)) } // APIKeyHasPrefix applies the HasPrefix predicate on the "api_key" field. func APIKeyHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldAPIKey, v)) } // APIKeyHasSuffix applies the HasSuffix predicate on the "api_key" field. func APIKeyHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldAPIKey, v)) } // APIKeyEqualFold applies the EqualFold predicate on the "api_key" field. func APIKeyEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldAPIKey, v)) } // APIKeyContainsFold applies the ContainsFold predicate on the "api_key" field. func APIKeyContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAPIKey), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldAPIKey, v)) } // RevokedEQ applies the EQ predicate on the "revoked" field. func RevokedEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldRevoked, v)) } // RevokedNEQ applies the NEQ predicate on the "revoked" field. func RevokedNEQ(v bool) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRevoked), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldRevoked, v)) } // IPAddressEQ applies the EQ predicate on the "ip_address" field. func IPAddressEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldIPAddress, v)) } // IPAddressNEQ applies the NEQ predicate on the "ip_address" field. func IPAddressNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldIPAddress, v)) } // IPAddressIn applies the In predicate on the "ip_address" field. func IPAddressIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldIPAddress, vs...)) } // IPAddressNotIn applies the NotIn predicate on the "ip_address" field. func IPAddressNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPAddress), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldIPAddress, vs...)) } // IPAddressGT applies the GT predicate on the "ip_address" field. func IPAddressGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldIPAddress, v)) } // IPAddressGTE applies the GTE predicate on the "ip_address" field. func IPAddressGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldIPAddress, v)) } // IPAddressLT applies the LT predicate on the "ip_address" field. func IPAddressLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldIPAddress, v)) } // IPAddressLTE applies the LTE predicate on the "ip_address" field. func IPAddressLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldIPAddress, v)) } // IPAddressContains applies the Contains predicate on the "ip_address" field. func IPAddressContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldIPAddress, v)) } // IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. func IPAddressHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldIPAddress, v)) } // IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. func IPAddressHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldIPAddress, v)) } // IPAddressIsNil applies the IsNil predicate on the "ip_address" field. func IPAddressIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldIPAddress)) } // IPAddressNotNil applies the NotNil predicate on the "ip_address" field. func IPAddressNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPAddress))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldIPAddress)) } // IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. func IPAddressEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldIPAddress, v)) } // IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. func IPAddressContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIPAddress), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldIPAddress, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. func TypeGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldType, v)) } // TypeIsNil applies the IsNil predicate on the "type" field. func TypeIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldType))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldType)) } // TypeNotNil applies the NotNil predicate on the "type" field. func TypeNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldType))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldType)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldType, v)) } // VersionEQ applies the EQ predicate on the "version" field. func VersionEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldVersion, v)) } // VersionNEQ applies the NEQ predicate on the "version" field. func VersionNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldVersion, v)) } // VersionIn applies the In predicate on the "version" field. func VersionIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldVersion), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldVersion, vs...)) } // VersionNotIn applies the NotIn predicate on the "version" field. func VersionNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldVersion), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldVersion, vs...)) } // VersionGT applies the GT predicate on the "version" field. func VersionGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldVersion, v)) } // VersionGTE applies the GTE predicate on the "version" field. func VersionGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldVersion, v)) } // VersionLT applies the LT predicate on the "version" field. func VersionLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldVersion, v)) } // VersionLTE applies the LTE predicate on the "version" field. func VersionLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldVersion, v)) } // VersionContains applies the Contains predicate on the "version" field. func VersionContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldVersion, v)) } // VersionHasPrefix applies the HasPrefix predicate on the "version" field. func VersionHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldVersion, v)) } // VersionHasSuffix applies the HasSuffix predicate on the "version" field. func VersionHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldVersion, v)) } // VersionIsNil applies the IsNil predicate on the "version" field. func VersionIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldVersion))) - }) + return predicate.Bouncer(sql.FieldIsNull(FieldVersion)) } // VersionNotNil applies the NotNil predicate on the "version" field. func VersionNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldVersion))) - }) + return predicate.Bouncer(sql.FieldNotNull(FieldVersion)) } // VersionEqualFold applies the EqualFold predicate on the "version" field. func VersionEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldVersion), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldVersion, v)) } // VersionContainsFold applies the ContainsFold predicate on the "version" field. func VersionContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldVersion), v)) - }) -} - -// UntilEQ applies the EQ predicate on the "until" field. -func UntilEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) -} - -// UntilNEQ applies the NEQ predicate on the "until" field. -func UntilNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUntil), v)) - }) -} - -// UntilIn applies the In predicate on the "until" field. -func UntilIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUntil), v...)) - }) -} - -// UntilNotIn applies the NotIn predicate on the "until" field. -func UntilNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUntil), v...)) - }) -} - -// UntilGT applies the GT predicate on the "until" field. -func UntilGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUntil), v)) - }) -} - -// UntilGTE applies the GTE predicate on the "until" field. -func UntilGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUntil), v)) - }) -} - -// UntilLT applies the LT predicate on the "until" field. -func UntilLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUntil), v)) - }) -} - -// UntilLTE applies the LTE predicate on the "until" field. -func UntilLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUntil), v)) - }) -} - -// UntilIsNil applies the IsNil predicate on the "until" field. -func UntilIsNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUntil))) - }) -} - -// UntilNotNil applies the NotNil predicate on the "until" field. -func UntilNotNil() predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUntil))) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldVersion, v)) } // LastPullEQ applies the EQ predicate on the "last_pull" field. func LastPullEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldLastPull, v)) } // LastPullNEQ applies the NEQ predicate on the "last_pull" field. func LastPullNEQ(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldLastPull, v)) } // LastPullIn applies the In predicate on the "last_pull" field. func LastPullIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastPull), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldLastPull, vs...)) } // LastPullNotIn applies the NotIn predicate on the "last_pull" field. func LastPullNotIn(vs ...time.Time) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastPull), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldLastPull, vs...)) } // LastPullGT applies the GT predicate on the "last_pull" field. func LastPullGT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldLastPull, v)) } // LastPullGTE applies the GTE predicate on the "last_pull" field. func LastPullGTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldLastPull, v)) } // LastPullLT applies the LT predicate on the "last_pull" field. func LastPullLT(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldLastPull, v)) } // LastPullLTE applies the LTE predicate on the "last_pull" field. func LastPullLTE(v time.Time) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastPull), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldLastPull, v)) +} + +// LastPullIsNil applies the IsNil predicate on the "last_pull" field. +func LastPullIsNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldIsNull(FieldLastPull)) +} + +// LastPullNotNil applies the NotNil predicate on the "last_pull" field. +func LastPullNotNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotNull(FieldLastPull)) } // AuthTypeEQ applies the EQ predicate on the "auth_type" field. func AuthTypeEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEQ(FieldAuthType, v)) } // AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. func AuthTypeNEQ(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldNEQ(FieldAuthType, v)) } // AuthTypeIn applies the In predicate on the "auth_type" field. func AuthTypeIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAuthType), v...)) - }) + return predicate.Bouncer(sql.FieldIn(FieldAuthType, vs...)) } // AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. func AuthTypeNotIn(vs ...string) predicate.Bouncer { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAuthType), v...)) - }) + return predicate.Bouncer(sql.FieldNotIn(FieldAuthType, vs...)) } // AuthTypeGT applies the GT predicate on the "auth_type" field. func AuthTypeGT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldGT(FieldAuthType, v)) } // AuthTypeGTE applies the GTE predicate on the "auth_type" field. func AuthTypeGTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldGTE(FieldAuthType, v)) } // AuthTypeLT applies the LT predicate on the "auth_type" field. func AuthTypeLT(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldLT(FieldAuthType, v)) } // AuthTypeLTE applies the LTE predicate on the "auth_type" field. func AuthTypeLTE(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldLTE(FieldAuthType, v)) } // AuthTypeContains applies the Contains predicate on the "auth_type" field. func AuthTypeContains(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldContains(FieldAuthType, v)) } // AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. func AuthTypeHasPrefix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldHasPrefix(FieldAuthType, v)) } // AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. func AuthTypeHasSuffix(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldHasSuffix(FieldAuthType, v)) } // AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. func AuthTypeEqualFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldEqualFold(FieldAuthType, v)) } // AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. func AuthTypeContainsFold(v string) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) - }) + return predicate.Bouncer(sql.FieldContainsFold(FieldAuthType, v)) +} + +// OsnameEQ applies the EQ predicate on the "osname" field. +func OsnameEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsname, v)) +} + +// OsnameNEQ applies the NEQ predicate on the "osname" field. +func OsnameNEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldOsname, v)) +} + +// OsnameIn applies the In predicate on the "osname" field. +func OsnameIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldIn(FieldOsname, vs...)) +} + +// OsnameNotIn applies the NotIn predicate on the "osname" field. +func OsnameNotIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotIn(FieldOsname, vs...)) +} + +// OsnameGT applies the GT predicate on the "osname" field. +func OsnameGT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGT(FieldOsname, v)) +} + +// OsnameGTE applies the GTE predicate on the "osname" field. +func OsnameGTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGTE(FieldOsname, v)) +} + +// OsnameLT applies the LT predicate on the "osname" field. +func OsnameLT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLT(FieldOsname, v)) +} + +// OsnameLTE applies the LTE predicate on the "osname" field. +func OsnameLTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLTE(FieldOsname, v)) +} + +// OsnameContains applies the Contains predicate on the "osname" field. +func OsnameContains(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContains(FieldOsname, v)) +} + +// OsnameHasPrefix applies the HasPrefix predicate on the "osname" field. +func OsnameHasPrefix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasPrefix(FieldOsname, v)) +} + +// OsnameHasSuffix applies the HasSuffix predicate on the "osname" field. +func OsnameHasSuffix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasSuffix(FieldOsname, v)) +} + +// OsnameIsNil applies the IsNil predicate on the "osname" field. +func OsnameIsNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldIsNull(FieldOsname)) +} + +// OsnameNotNil applies the NotNil predicate on the "osname" field. +func OsnameNotNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotNull(FieldOsname)) +} + +// OsnameEqualFold applies the EqualFold predicate on the "osname" field. +func OsnameEqualFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEqualFold(FieldOsname, v)) +} + +// OsnameContainsFold applies the ContainsFold predicate on the "osname" field. +func OsnameContainsFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContainsFold(FieldOsname, v)) +} + +// OsversionEQ applies the EQ predicate on the "osversion" field. +func OsversionEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldOsversion, v)) +} + +// OsversionNEQ applies the NEQ predicate on the "osversion" field. +func OsversionNEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldOsversion, v)) +} + +// OsversionIn applies the In predicate on the "osversion" field. +func OsversionIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldIn(FieldOsversion, vs...)) +} + +// OsversionNotIn applies the NotIn predicate on the "osversion" field. +func OsversionNotIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotIn(FieldOsversion, vs...)) +} + +// OsversionGT applies the GT predicate on the "osversion" field. +func OsversionGT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGT(FieldOsversion, v)) +} + +// OsversionGTE applies the GTE predicate on the "osversion" field. +func OsversionGTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGTE(FieldOsversion, v)) +} + +// OsversionLT applies the LT predicate on the "osversion" field. +func OsversionLT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLT(FieldOsversion, v)) +} + +// OsversionLTE applies the LTE predicate on the "osversion" field. +func OsversionLTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLTE(FieldOsversion, v)) +} + +// OsversionContains applies the Contains predicate on the "osversion" field. +func OsversionContains(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContains(FieldOsversion, v)) +} + +// OsversionHasPrefix applies the HasPrefix predicate on the "osversion" field. +func OsversionHasPrefix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasPrefix(FieldOsversion, v)) +} + +// OsversionHasSuffix applies the HasSuffix predicate on the "osversion" field. +func OsversionHasSuffix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasSuffix(FieldOsversion, v)) +} + +// OsversionIsNil applies the IsNil predicate on the "osversion" field. +func OsversionIsNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldIsNull(FieldOsversion)) +} + +// OsversionNotNil applies the NotNil predicate on the "osversion" field. +func OsversionNotNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotNull(FieldOsversion)) +} + +// OsversionEqualFold applies the EqualFold predicate on the "osversion" field. +func OsversionEqualFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEqualFold(FieldOsversion, v)) +} + +// OsversionContainsFold applies the ContainsFold predicate on the "osversion" field. +func OsversionContainsFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContainsFold(FieldOsversion, v)) +} + +// FeatureflagsEQ applies the EQ predicate on the "featureflags" field. +func FeatureflagsEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) +} + +// FeatureflagsNEQ applies the NEQ predicate on the "featureflags" field. +func FeatureflagsNEQ(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldFeatureflags, v)) +} + +// FeatureflagsIn applies the In predicate on the "featureflags" field. +func FeatureflagsIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsNotIn applies the NotIn predicate on the "featureflags" field. +func FeatureflagsNotIn(vs ...string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsGT applies the GT predicate on the "featureflags" field. +func FeatureflagsGT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGT(FieldFeatureflags, v)) +} + +// FeatureflagsGTE applies the GTE predicate on the "featureflags" field. +func FeatureflagsGTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldGTE(FieldFeatureflags, v)) +} + +// FeatureflagsLT applies the LT predicate on the "featureflags" field. +func FeatureflagsLT(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLT(FieldFeatureflags, v)) +} + +// FeatureflagsLTE applies the LTE predicate on the "featureflags" field. +func FeatureflagsLTE(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldLTE(FieldFeatureflags, v)) +} + +// FeatureflagsContains applies the Contains predicate on the "featureflags" field. +func FeatureflagsContains(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContains(FieldFeatureflags, v)) +} + +// FeatureflagsHasPrefix applies the HasPrefix predicate on the "featureflags" field. +func FeatureflagsHasPrefix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasPrefix(FieldFeatureflags, v)) +} + +// FeatureflagsHasSuffix applies the HasSuffix predicate on the "featureflags" field. +func FeatureflagsHasSuffix(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldHasSuffix(FieldFeatureflags, v)) +} + +// FeatureflagsIsNil applies the IsNil predicate on the "featureflags" field. +func FeatureflagsIsNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldIsNull(FieldFeatureflags)) +} + +// FeatureflagsNotNil applies the NotNil predicate on the "featureflags" field. +func FeatureflagsNotNil() predicate.Bouncer { + return predicate.Bouncer(sql.FieldNotNull(FieldFeatureflags)) +} + +// FeatureflagsEqualFold applies the EqualFold predicate on the "featureflags" field. +func FeatureflagsEqualFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEqualFold(FieldFeatureflags, v)) +} + +// FeatureflagsContainsFold applies the ContainsFold predicate on the "featureflags" field. +func FeatureflagsContainsFold(v string) predicate.Bouncer { + return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v)) } // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Bouncer(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Bouncer(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Bouncer) predicate.Bouncer { - return predicate.Bouncer(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Bouncer(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 685ce089d1e..29b23f87cf1 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -108,20 +108,6 @@ func (bc *BouncerCreate) SetNillableVersion(s *string) *BouncerCreate { return bc } -// SetUntil sets the "until" field. -func (bc *BouncerCreate) SetUntil(t time.Time) *BouncerCreate { - bc.mutation.SetUntil(t) - return bc -} - -// SetNillableUntil sets the "until" field if the given value is not nil. -func (bc *BouncerCreate) SetNillableUntil(t *time.Time) *BouncerCreate { - if t != nil { - bc.SetUntil(*t) - } - return bc -} - // SetLastPull sets the "last_pull" field. func (bc *BouncerCreate) SetLastPull(t time.Time) *BouncerCreate { bc.mutation.SetLastPull(t) @@ -150,6 +136,48 @@ func (bc *BouncerCreate) SetNillableAuthType(s *string) *BouncerCreate { return bc } +// SetOsname sets the "osname" field. +func (bc *BouncerCreate) SetOsname(s string) *BouncerCreate { + bc.mutation.SetOsname(s) + return bc +} + +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableOsname(s *string) *BouncerCreate { + if s != nil { + bc.SetOsname(*s) + } + return bc +} + +// SetOsversion sets the "osversion" field. +func (bc *BouncerCreate) SetOsversion(s string) *BouncerCreate { + bc.mutation.SetOsversion(s) + return bc +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableOsversion(s *string) *BouncerCreate { + if s != nil { + bc.SetOsversion(*s) + } + return bc +} + +// SetFeatureflags sets the "featureflags" field. +func (bc *BouncerCreate) SetFeatureflags(s string) *BouncerCreate { + bc.mutation.SetFeatureflags(s) + return bc +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate { + if s != nil { + bc.SetFeatureflags(*s) + } + return bc +} + // Mutation returns the BouncerMutation object of the builder. func (bc *BouncerCreate) Mutation() *BouncerMutation { return bc.mutation @@ -157,50 +185,8 @@ func (bc *BouncerCreate) Mutation() *BouncerMutation { // Save creates the Bouncer in the database. func (bc *BouncerCreate) Save(ctx context.Context) (*Bouncer, error) { - var ( - err error - node *Bouncer - ) bc.defaults() - if len(bc.hooks) == 0 { - if err = bc.check(); err != nil { - return nil, err - } - node, err = bc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = bc.check(); err != nil { - return nil, err - } - bc.mutation = mutation - if node, err = bc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(bc.hooks) - 1; i >= 0; i-- { - if bc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, bc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Bouncer) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, bc.sqlSave, bc.mutation, bc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -239,14 +225,6 @@ func (bc *BouncerCreate) defaults() { v := bouncer.DefaultIPAddress bc.mutation.SetIPAddress(v) } - if _, ok := bc.mutation.Until(); !ok { - v := bouncer.DefaultUntil() - bc.mutation.SetUntil(v) - } - if _, ok := bc.mutation.LastPull(); !ok { - v := bouncer.DefaultLastPull() - bc.mutation.SetLastPull(v) - } if _, ok := bc.mutation.AuthType(); !ok { v := bouncer.DefaultAuthType bc.mutation.SetAuthType(v) @@ -255,6 +233,12 @@ func (bc *BouncerCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (bc *BouncerCreate) check() error { + if _, ok := bc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Bouncer.created_at"`)} + } + if _, ok := bc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Bouncer.updated_at"`)} + } if _, ok := bc.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Bouncer.name"`)} } @@ -264,9 +248,6 @@ func (bc *BouncerCreate) check() error { if _, ok := bc.mutation.Revoked(); !ok { return &ValidationError{Name: "revoked", err: errors.New(`ent: missing required field "Bouncer.revoked"`)} } - if _, ok := bc.mutation.LastPull(); !ok { - return &ValidationError{Name: "last_pull", err: errors.New(`ent: missing required field "Bouncer.last_pull"`)} - } if _, ok := bc.mutation.AuthType(); !ok { return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)} } @@ -274,6 +255,9 @@ func (bc *BouncerCreate) check() error { } func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) { + if err := bc.check(); err != nil { + return nil, err + } _node, _spec := bc.createSpec() if err := sqlgraph.CreateNode(ctx, bc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -283,119 +267,83 @@ func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + bc.mutation.id = &_node.ID + bc.mutation.done = true return _node, nil } func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { var ( _node = &Bouncer{config: bc.config} - _spec = &sqlgraph.CreateSpec{ - Table: bouncer.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) ) if value, ok := bc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := bc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := bc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldName, field.TypeString, value) _node.Name = value } if value, ok := bc.mutation.APIKey(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) _node.APIKey = value } if value, ok := bc.mutation.Revoked(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) _node.Revoked = value } if value, ok := bc.mutation.IPAddress(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) _node.IPAddress = value } if value, ok := bc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) _node.Type = value } if value, ok := bc.mutation.Version(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) _node.Version = value } - if value, ok := bc.mutation.Until(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) - _node.Until = value - } if value, ok := bc.mutation.LastPull(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) - _node.LastPull = value + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) + _node.LastPull = &value } if value, ok := bc.mutation.AuthType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) _node.AuthType = value } + if value, ok := bc.mutation.Osname(); ok { + _spec.SetField(bouncer.FieldOsname, field.TypeString, value) + _node.Osname = value + } + if value, ok := bc.mutation.Osversion(); ok { + _spec.SetField(bouncer.FieldOsversion, field.TypeString, value) + _node.Osversion = value + } + if value, ok := bc.mutation.Featureflags(); ok { + _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) + _node.Featureflags = value + } return _node, _spec } // BouncerCreateBulk is the builder for creating many Bouncer entities in bulk. type BouncerCreateBulk struct { config + err error builders []*BouncerCreate } // Save creates the Bouncer entities in the database. func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) { + if bcb.err != nil { + return nil, bcb.err + } specs := make([]*sqlgraph.CreateSpec, len(bcb.builders)) nodes := make([]*Bouncer, len(bcb.builders)) mutators := make([]Mutator, len(bcb.builders)) @@ -412,8 +360,8 @@ func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, bcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/bouncer_delete.go b/pkg/database/ent/bouncer_delete.go index 6bfb9459190..bf459e77e28 100644 --- a/pkg/database/ent/bouncer_delete.go +++ b/pkg/database/ent/bouncer_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (bd *BouncerDelete) Where(ps ...predicate.Bouncer) *BouncerDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (bd *BouncerDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(bd.hooks) == 0 { - affected, err = bd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - bd.mutation = mutation - affected, err = bd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(bd.hooks) - 1; i >= 0; i-- { - if bd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, bd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, bd.sqlExec, bd.mutation, bd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (bd *BouncerDelete) ExecX(ctx context.Context) int { } func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) if ps := bd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + bd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type BouncerDeleteOne struct { bd *BouncerDelete } +// Where appends a list predicates to the BouncerDelete builder. +func (bdo *BouncerDeleteOne) Where(ps ...predicate.Bouncer) *BouncerDeleteOne { + bdo.bd.mutation.Where(ps...) + return bdo +} + // Exec executes the deletion query. func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error { n, err := bdo.bd.Exec(ctx) @@ -111,5 +82,7 @@ func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (bdo *BouncerDeleteOne) ExecX(ctx context.Context) { - bdo.bd.ExecX(ctx) + if err := bdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/bouncer_query.go b/pkg/database/ent/bouncer_query.go index 2747a3e0b3a..ea2b7495733 100644 --- a/pkg/database/ent/bouncer_query.go +++ b/pkg/database/ent/bouncer_query.go @@ -17,11 +17,9 @@ import ( // BouncerQuery is the builder for querying Bouncer entities. type BouncerQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []bouncer.OrderOption + inters []Interceptor predicates []predicate.Bouncer // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,27 +32,27 @@ func (bq *BouncerQuery) Where(ps ...predicate.Bouncer) *BouncerQuery { return bq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (bq *BouncerQuery) Limit(limit int) *BouncerQuery { - bq.limit = &limit + bq.ctx.Limit = &limit return bq } -// Offset adds an offset step to the query. +// Offset to start from. func (bq *BouncerQuery) Offset(offset int) *BouncerQuery { - bq.offset = &offset + bq.ctx.Offset = &offset return bq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (bq *BouncerQuery) Unique(unique bool) *BouncerQuery { - bq.unique = &unique + bq.ctx.Unique = &unique return bq } -// Order adds an order step to the query. -func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery { +// Order specifies how the records should be ordered. +func (bq *BouncerQuery) Order(o ...bouncer.OrderOption) *BouncerQuery { bq.order = append(bq.order, o...) return bq } @@ -62,7 +60,7 @@ func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery { // First returns the first Bouncer entity from the query. // Returns a *NotFoundError when no Bouncer was found. func (bq *BouncerQuery) First(ctx context.Context) (*Bouncer, error) { - nodes, err := bq.Limit(1).All(ctx) + nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (bq *BouncerQuery) FirstX(ctx context.Context) *Bouncer { // Returns a *NotFoundError when no Bouncer ID was found. func (bq *BouncerQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(1).IDs(ctx); err != nil { + if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (bq *BouncerQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Bouncer entity is found. // Returns a *NotFoundError when no Bouncer entities are found. func (bq *BouncerQuery) Only(ctx context.Context) (*Bouncer, error) { - nodes, err := bq.Limit(2).All(ctx) + nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (bq *BouncerQuery) OnlyX(ctx context.Context) *Bouncer { // Returns a *NotFoundError when no entities are found. func (bq *BouncerQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = bq.Limit(2).IDs(ctx); err != nil { + if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (bq *BouncerQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Bouncers. func (bq *BouncerQuery) All(ctx context.Context) ([]*Bouncer, error) { + ctx = setContextOp(ctx, bq.ctx, "All") if err := bq.prepareQuery(ctx); err != nil { return nil, err } - return bq.sqlAll(ctx) + qr := querierAll[[]*Bouncer, *BouncerQuery]() + return withInterceptors[[]*Bouncer](ctx, bq, qr, bq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (bq *BouncerQuery) AllX(ctx context.Context) []*Bouncer { } // IDs executes the query and returns a list of Bouncer IDs. -func (bq *BouncerQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil { +func (bq *BouncerQuery) IDs(ctx context.Context) (ids []int, err error) { + if bq.ctx.Unique == nil && bq.path != nil { + bq.Unique(true) + } + ctx = setContextOp(ctx, bq.ctx, "IDs") + if err = bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (bq *BouncerQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (bq *BouncerQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, bq.ctx, "Count") if err := bq.prepareQuery(ctx); err != nil { return 0, err } - return bq.sqlCount(ctx) + return withInterceptors[int](ctx, bq, querierCount[*BouncerQuery](), bq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (bq *BouncerQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (bq *BouncerQuery) Exist(ctx context.Context) (bool, error) { - if err := bq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, bq.ctx, "Exist") + switch _, err := bq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return bq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (bq *BouncerQuery) Clone() *BouncerQuery { } return &BouncerQuery{ config: bq.config, - limit: bq.limit, - offset: bq.offset, - order: append([]OrderFunc{}, bq.order...), + ctx: bq.ctx.Clone(), + order: append([]bouncer.OrderOption{}, bq.order...), + inters: append([]Interceptor{}, bq.inters...), predicates: append([]predicate.Bouncer{}, bq.predicates...), // clone intermediate query. - sql: bq.sql.Clone(), - path: bq.path, - unique: bq.unique, + sql: bq.sql.Clone(), + path: bq.path, } } @@ -262,16 +270,11 @@ func (bq *BouncerQuery) Clone() *BouncerQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy { - grbuild := &BouncerGroupBy{config: bq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := bq.prepareQuery(ctx); err != nil { - return nil, err - } - return bq.sqlQuery(ctx), nil - } + bq.ctx.Fields = append([]string{field}, fields...) + grbuild := &BouncerGroupBy{build: bq} + grbuild.flds = &bq.ctx.Fields grbuild.label = bouncer.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,15 +291,30 @@ func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy // Select(bouncer.FieldCreatedAt). // Scan(ctx, &v) func (bq *BouncerQuery) Select(fields ...string) *BouncerSelect { - bq.fields = append(bq.fields, fields...) - selbuild := &BouncerSelect{BouncerQuery: bq} - selbuild.label = bouncer.Label - selbuild.flds, selbuild.scan = &bq.fields, selbuild.Scan - return selbuild + bq.ctx.Fields = append(bq.ctx.Fields, fields...) + sbuild := &BouncerSelect{BouncerQuery: bq} + sbuild.label = bouncer.Label + sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a BouncerSelect configured with the given aggregations. +func (bq *BouncerQuery) Aggregate(fns ...AggregateFunc) *BouncerSelect { + return bq.Select().Aggregate(fns...) } func (bq *BouncerQuery) prepareQuery(ctx context.Context) error { - for _, f := range bq.fields { + for _, inter := range bq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, bq); err != nil { + return err + } + } + } + for _, f := range bq.ctx.Fields { if !bouncer.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -338,41 +356,22 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Boun func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) { _spec := bq.querySpec() - _spec.Node.Columns = bq.fields - if len(bq.fields) > 0 { - _spec.Unique = bq.unique != nil && *bq.unique + _spec.Node.Columns = bq.ctx.Fields + if len(bq.ctx.Fields) > 0 { + _spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique } return sqlgraph.CountNodes(ctx, bq.driver, _spec) } -func (bq *BouncerQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := bq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - From: bq.sql, - Unique: true, - } - if unique := bq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) + _spec.From = bq.sql + if unique := bq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if bq.path != nil { + _spec.Unique = true } - if fields := bq.fields; len(fields) > 0 { + if fields := bq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, bouncer.FieldID) for i := range fields { @@ -388,10 +387,10 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := bq.order; len(ps) > 0 { @@ -407,7 +406,7 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec { func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(bq.driver.Dialect()) t1 := builder.Table(bouncer.Table) - columns := bq.fields + columns := bq.ctx.Fields if len(columns) == 0 { columns = bouncer.Columns } @@ -416,7 +415,7 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = bq.sql selector.Select(selector.Columns(columns...)...) } - if bq.unique != nil && *bq.unique { + if bq.ctx.Unique != nil && *bq.ctx.Unique { selector.Distinct() } for _, p := range bq.predicates { @@ -425,12 +424,12 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range bq.order { p(selector) } - if offset := bq.offset; offset != nil { + if offset := bq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := bq.limit; limit != nil { + if limit := bq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -438,13 +437,8 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector { // BouncerGroupBy is the group-by builder for Bouncer entities. type BouncerGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *BouncerQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -453,74 +447,77 @@ func (bgb *BouncerGroupBy) Aggregate(fns ...AggregateFunc) *BouncerGroupBy { return bgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (bgb *BouncerGroupBy) Scan(ctx context.Context, v any) error { - query, err := bgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy") + if err := bgb.build.prepareQuery(ctx); err != nil { return err } - bgb.sql = query - return bgb.sqlScan(ctx, v) + return scanWithInterceptors[*BouncerQuery, *BouncerGroupBy](ctx, bgb.build, bgb, bgb.build.inters, v) } -func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range bgb.fields { - if !bouncer.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, root *BouncerQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(bgb.fns)) + for _, fn := range bgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*bgb.flds)+len(bgb.fns)) + for _, f := range *bgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := bgb.sqlQuery() + selector.GroupBy(selector.Columns(*bgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := bgb.driver.Query(ctx, query, args, rows); err != nil { + if err := bgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (bgb *BouncerGroupBy) sqlQuery() *sql.Selector { - selector := bgb.sql.Select() - aggregation := make([]string, 0, len(bgb.fns)) - for _, fn := range bgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(bgb.fields)+len(bgb.fns)) - for _, f := range bgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(bgb.fields...)...) -} - // BouncerSelect is the builder for selecting fields of Bouncer entities. type BouncerSelect struct { *BouncerQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (bs *BouncerSelect) Aggregate(fns ...AggregateFunc) *BouncerSelect { + bs.fns = append(bs.fns, fns...) + return bs } // Scan applies the selector query and scans the result into the given value. func (bs *BouncerSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, bs.ctx, "Select") if err := bs.prepareQuery(ctx); err != nil { return err } - bs.sql = bs.BouncerQuery.sqlQuery(ctx) - return bs.sqlScan(ctx, v) + return scanWithInterceptors[*BouncerQuery, *BouncerSelect](ctx, bs.BouncerQuery, bs, bs.inters, v) } -func (bs *BouncerSelect) sqlScan(ctx context.Context, v any) error { +func (bs *BouncerSelect) sqlScan(ctx context.Context, root *BouncerQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(bs.fns)) + for _, fn := range bs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*bs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := bs.sql.Query() + query, args := selector.Query() if err := bs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/bouncer_update.go b/pkg/database/ent/bouncer_update.go index acf48dedeec..620b006a49a 100644 --- a/pkg/database/ent/bouncer_update.go +++ b/pkg/database/ent/bouncer_update.go @@ -28,48 +28,40 @@ func (bu *BouncerUpdate) Where(ps ...predicate.Bouncer) *BouncerUpdate { return bu } -// SetCreatedAt sets the "created_at" field. -func (bu *BouncerUpdate) SetCreatedAt(t time.Time) *BouncerUpdate { - bu.mutation.SetCreatedAt(t) - return bu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (bu *BouncerUpdate) ClearCreatedAt() *BouncerUpdate { - bu.mutation.ClearCreatedAt() - return bu -} - // SetUpdatedAt sets the "updated_at" field. func (bu *BouncerUpdate) SetUpdatedAt(t time.Time) *BouncerUpdate { bu.mutation.SetUpdatedAt(t) return bu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (bu *BouncerUpdate) ClearUpdatedAt() *BouncerUpdate { - bu.mutation.ClearUpdatedAt() - return bu -} - -// SetName sets the "name" field. -func (bu *BouncerUpdate) SetName(s string) *BouncerUpdate { - bu.mutation.SetName(s) - return bu -} - // SetAPIKey sets the "api_key" field. func (bu *BouncerUpdate) SetAPIKey(s string) *BouncerUpdate { bu.mutation.SetAPIKey(s) return bu } +// SetNillableAPIKey sets the "api_key" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableAPIKey(s *string) *BouncerUpdate { + if s != nil { + bu.SetAPIKey(*s) + } + return bu +} + // SetRevoked sets the "revoked" field. func (bu *BouncerUpdate) SetRevoked(b bool) *BouncerUpdate { bu.mutation.SetRevoked(b) return bu } +// SetNillableRevoked sets the "revoked" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableRevoked(b *bool) *BouncerUpdate { + if b != nil { + bu.SetRevoked(*b) + } + return bu +} + // SetIPAddress sets the "ip_address" field. func (bu *BouncerUpdate) SetIPAddress(s string) *BouncerUpdate { bu.mutation.SetIPAddress(s) @@ -130,26 +122,6 @@ func (bu *BouncerUpdate) ClearVersion() *BouncerUpdate { return bu } -// SetUntil sets the "until" field. -func (bu *BouncerUpdate) SetUntil(t time.Time) *BouncerUpdate { - bu.mutation.SetUntil(t) - return bu -} - -// SetNillableUntil sets the "until" field if the given value is not nil. -func (bu *BouncerUpdate) SetNillableUntil(t *time.Time) *BouncerUpdate { - if t != nil { - bu.SetUntil(*t) - } - return bu -} - -// ClearUntil clears the value of the "until" field. -func (bu *BouncerUpdate) ClearUntil() *BouncerUpdate { - bu.mutation.ClearUntil() - return bu -} - // SetLastPull sets the "last_pull" field. func (bu *BouncerUpdate) SetLastPull(t time.Time) *BouncerUpdate { bu.mutation.SetLastPull(t) @@ -164,6 +136,12 @@ func (bu *BouncerUpdate) SetNillableLastPull(t *time.Time) *BouncerUpdate { return bu } +// ClearLastPull clears the value of the "last_pull" field. +func (bu *BouncerUpdate) ClearLastPull() *BouncerUpdate { + bu.mutation.ClearLastPull() + return bu +} + // SetAuthType sets the "auth_type" field. func (bu *BouncerUpdate) SetAuthType(s string) *BouncerUpdate { bu.mutation.SetAuthType(s) @@ -178,6 +156,66 @@ func (bu *BouncerUpdate) SetNillableAuthType(s *string) *BouncerUpdate { return bu } +// SetOsname sets the "osname" field. +func (bu *BouncerUpdate) SetOsname(s string) *BouncerUpdate { + bu.mutation.SetOsname(s) + return bu +} + +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableOsname(s *string) *BouncerUpdate { + if s != nil { + bu.SetOsname(*s) + } + return bu +} + +// ClearOsname clears the value of the "osname" field. +func (bu *BouncerUpdate) ClearOsname() *BouncerUpdate { + bu.mutation.ClearOsname() + return bu +} + +// SetOsversion sets the "osversion" field. +func (bu *BouncerUpdate) SetOsversion(s string) *BouncerUpdate { + bu.mutation.SetOsversion(s) + return bu +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableOsversion(s *string) *BouncerUpdate { + if s != nil { + bu.SetOsversion(*s) + } + return bu +} + +// ClearOsversion clears the value of the "osversion" field. +func (bu *BouncerUpdate) ClearOsversion() *BouncerUpdate { + bu.mutation.ClearOsversion() + return bu +} + +// SetFeatureflags sets the "featureflags" field. +func (bu *BouncerUpdate) SetFeatureflags(s string) *BouncerUpdate { + bu.mutation.SetFeatureflags(s) + return bu +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableFeatureflags(s *string) *BouncerUpdate { + if s != nil { + bu.SetFeatureflags(*s) + } + return bu +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (bu *BouncerUpdate) ClearFeatureflags() *BouncerUpdate { + bu.mutation.ClearFeatureflags() + return bu +} + // Mutation returns the BouncerMutation object of the builder. func (bu *BouncerUpdate) Mutation() *BouncerMutation { return bu.mutation @@ -185,35 +223,8 @@ func (bu *BouncerUpdate) Mutation() *BouncerMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (bu *BouncerUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) bu.defaults() - if len(bu.hooks) == 0 { - affected, err = bu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - bu.mutation = mutation - affected, err = bu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(bu.hooks) - 1; i >= 0; i-- { - if bu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = bu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, bu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, bu.sqlSave, bu.mutation, bu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -240,27 +251,14 @@ func (bu *BouncerUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (bu *BouncerUpdate) defaults() { - if _, ok := bu.mutation.CreatedAt(); !ok && !bu.mutation.CreatedAtCleared() { - v := bouncer.UpdateDefaultCreatedAt() - bu.mutation.SetCreatedAt(v) - } - if _, ok := bu.mutation.UpdatedAt(); !ok && !bu.mutation.UpdatedAtCleared() { + if _, ok := bu.mutation.UpdatedAt(); !ok { v := bouncer.UpdateDefaultUpdatedAt() bu.mutation.SetUpdatedAt(v) } } func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) if ps := bu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -268,118 +266,59 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := bu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) - } - if bu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) - } if value, ok := bu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) - } - if bu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) - } - if value, ok := bu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if value, ok := bu.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := bu.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } if value, ok := bu.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if bu.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := bu.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if bu.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := bu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if bu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) - } - if value, ok := bu.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) - } - if bu.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := bu.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) + } + if bu.mutation.LastPullCleared() { + _spec.ClearField(bouncer.FieldLastPull, field.TypeTime) } if value, ok := bu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) + } + if value, ok := bu.mutation.Osname(); ok { + _spec.SetField(bouncer.FieldOsname, field.TypeString, value) + } + if bu.mutation.OsnameCleared() { + _spec.ClearField(bouncer.FieldOsname, field.TypeString) + } + if value, ok := bu.mutation.Osversion(); ok { + _spec.SetField(bouncer.FieldOsversion, field.TypeString, value) + } + if bu.mutation.OsversionCleared() { + _spec.ClearField(bouncer.FieldOsversion, field.TypeString) + } + if value, ok := bu.mutation.Featureflags(); ok { + _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) + } + if bu.mutation.FeatureflagsCleared() { + _spec.ClearField(bouncer.FieldFeatureflags, field.TypeString) } if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -389,6 +328,7 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + bu.mutation.done = true return n, nil } @@ -400,48 +340,40 @@ type BouncerUpdateOne struct { mutation *BouncerMutation } -// SetCreatedAt sets the "created_at" field. -func (buo *BouncerUpdateOne) SetCreatedAt(t time.Time) *BouncerUpdateOne { - buo.mutation.SetCreatedAt(t) - return buo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (buo *BouncerUpdateOne) ClearCreatedAt() *BouncerUpdateOne { - buo.mutation.ClearCreatedAt() - return buo -} - // SetUpdatedAt sets the "updated_at" field. func (buo *BouncerUpdateOne) SetUpdatedAt(t time.Time) *BouncerUpdateOne { buo.mutation.SetUpdatedAt(t) return buo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (buo *BouncerUpdateOne) ClearUpdatedAt() *BouncerUpdateOne { - buo.mutation.ClearUpdatedAt() - return buo -} - -// SetName sets the "name" field. -func (buo *BouncerUpdateOne) SetName(s string) *BouncerUpdateOne { - buo.mutation.SetName(s) - return buo -} - // SetAPIKey sets the "api_key" field. func (buo *BouncerUpdateOne) SetAPIKey(s string) *BouncerUpdateOne { buo.mutation.SetAPIKey(s) return buo } +// SetNillableAPIKey sets the "api_key" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableAPIKey(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetAPIKey(*s) + } + return buo +} + // SetRevoked sets the "revoked" field. func (buo *BouncerUpdateOne) SetRevoked(b bool) *BouncerUpdateOne { buo.mutation.SetRevoked(b) return buo } +// SetNillableRevoked sets the "revoked" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableRevoked(b *bool) *BouncerUpdateOne { + if b != nil { + buo.SetRevoked(*b) + } + return buo +} + // SetIPAddress sets the "ip_address" field. func (buo *BouncerUpdateOne) SetIPAddress(s string) *BouncerUpdateOne { buo.mutation.SetIPAddress(s) @@ -502,26 +434,6 @@ func (buo *BouncerUpdateOne) ClearVersion() *BouncerUpdateOne { return buo } -// SetUntil sets the "until" field. -func (buo *BouncerUpdateOne) SetUntil(t time.Time) *BouncerUpdateOne { - buo.mutation.SetUntil(t) - return buo -} - -// SetNillableUntil sets the "until" field if the given value is not nil. -func (buo *BouncerUpdateOne) SetNillableUntil(t *time.Time) *BouncerUpdateOne { - if t != nil { - buo.SetUntil(*t) - } - return buo -} - -// ClearUntil clears the value of the "until" field. -func (buo *BouncerUpdateOne) ClearUntil() *BouncerUpdateOne { - buo.mutation.ClearUntil() - return buo -} - // SetLastPull sets the "last_pull" field. func (buo *BouncerUpdateOne) SetLastPull(t time.Time) *BouncerUpdateOne { buo.mutation.SetLastPull(t) @@ -536,6 +448,12 @@ func (buo *BouncerUpdateOne) SetNillableLastPull(t *time.Time) *BouncerUpdateOne return buo } +// ClearLastPull clears the value of the "last_pull" field. +func (buo *BouncerUpdateOne) ClearLastPull() *BouncerUpdateOne { + buo.mutation.ClearLastPull() + return buo +} + // SetAuthType sets the "auth_type" field. func (buo *BouncerUpdateOne) SetAuthType(s string) *BouncerUpdateOne { buo.mutation.SetAuthType(s) @@ -550,11 +468,77 @@ func (buo *BouncerUpdateOne) SetNillableAuthType(s *string) *BouncerUpdateOne { return buo } +// SetOsname sets the "osname" field. +func (buo *BouncerUpdateOne) SetOsname(s string) *BouncerUpdateOne { + buo.mutation.SetOsname(s) + return buo +} + +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableOsname(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetOsname(*s) + } + return buo +} + +// ClearOsname clears the value of the "osname" field. +func (buo *BouncerUpdateOne) ClearOsname() *BouncerUpdateOne { + buo.mutation.ClearOsname() + return buo +} + +// SetOsversion sets the "osversion" field. +func (buo *BouncerUpdateOne) SetOsversion(s string) *BouncerUpdateOne { + buo.mutation.SetOsversion(s) + return buo +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableOsversion(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetOsversion(*s) + } + return buo +} + +// ClearOsversion clears the value of the "osversion" field. +func (buo *BouncerUpdateOne) ClearOsversion() *BouncerUpdateOne { + buo.mutation.ClearOsversion() + return buo +} + +// SetFeatureflags sets the "featureflags" field. +func (buo *BouncerUpdateOne) SetFeatureflags(s string) *BouncerUpdateOne { + buo.mutation.SetFeatureflags(s) + return buo +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableFeatureflags(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetFeatureflags(*s) + } + return buo +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (buo *BouncerUpdateOne) ClearFeatureflags() *BouncerUpdateOne { + buo.mutation.ClearFeatureflags() + return buo +} + // Mutation returns the BouncerMutation object of the builder. func (buo *BouncerUpdateOne) Mutation() *BouncerMutation { return buo.mutation } +// Where appends a list predicates to the BouncerUpdate builder. +func (buo *BouncerUpdateOne) Where(ps ...predicate.Bouncer) *BouncerUpdateOne { + buo.mutation.Where(ps...) + return buo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpdateOne { @@ -564,41 +548,8 @@ func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpda // Save executes the query and returns the updated Bouncer entity. func (buo *BouncerUpdateOne) Save(ctx context.Context) (*Bouncer, error) { - var ( - err error - node *Bouncer - ) buo.defaults() - if len(buo.hooks) == 0 { - node, err = buo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - buo.mutation = mutation - node, err = buo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(buo.hooks) - 1; i >= 0; i-- { - if buo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = buo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, buo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Bouncer) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, buo.sqlSave, buo.mutation, buo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -625,27 +576,14 @@ func (buo *BouncerUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (buo *BouncerUpdateOne) defaults() { - if _, ok := buo.mutation.CreatedAt(); !ok && !buo.mutation.CreatedAtCleared() { - v := bouncer.UpdateDefaultCreatedAt() - buo.mutation.SetCreatedAt(v) - } - if _, ok := buo.mutation.UpdatedAt(); !ok && !buo.mutation.UpdatedAtCleared() { + if _, ok := buo.mutation.UpdatedAt(); !ok { v := bouncer.UpdateDefaultUpdatedAt() buo.mutation.SetUpdatedAt(v) } } func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: bouncer.Table, - Columns: bouncer.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: bouncer.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt)) id, ok := buo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Bouncer.id" for update`)} @@ -670,118 +608,59 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } } } - if value, ok := buo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldCreatedAt, - }) - } - if buo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldCreatedAt, - }) - } if value, ok := buo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUpdatedAt, - }) - } - if buo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUpdatedAt, - }) - } - if value, ok := buo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldName, - }) + _spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value) } if value, ok := buo.mutation.APIKey(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAPIKey, - }) + _spec.SetField(bouncer.FieldAPIKey, field.TypeString, value) } if value, ok := buo.mutation.Revoked(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: bouncer.FieldRevoked, - }) + _spec.SetField(bouncer.FieldRevoked, field.TypeBool, value) } if value, ok := buo.mutation.IPAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldIPAddress, - }) + _spec.SetField(bouncer.FieldIPAddress, field.TypeString, value) } if buo.mutation.IPAddressCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldIPAddress, - }) + _spec.ClearField(bouncer.FieldIPAddress, field.TypeString) } if value, ok := buo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldType, - }) + _spec.SetField(bouncer.FieldType, field.TypeString, value) } if buo.mutation.TypeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldType, - }) + _spec.ClearField(bouncer.FieldType, field.TypeString) } if value, ok := buo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldVersion, - }) + _spec.SetField(bouncer.FieldVersion, field.TypeString, value) } if buo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: bouncer.FieldVersion, - }) - } - if value, ok := buo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldUntil, - }) - } - if buo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: bouncer.FieldUntil, - }) + _spec.ClearField(bouncer.FieldVersion, field.TypeString) } if value, ok := buo.mutation.LastPull(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: bouncer.FieldLastPull, - }) + _spec.SetField(bouncer.FieldLastPull, field.TypeTime, value) + } + if buo.mutation.LastPullCleared() { + _spec.ClearField(bouncer.FieldLastPull, field.TypeTime) } if value, ok := buo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: bouncer.FieldAuthType, - }) + _spec.SetField(bouncer.FieldAuthType, field.TypeString, value) + } + if value, ok := buo.mutation.Osname(); ok { + _spec.SetField(bouncer.FieldOsname, field.TypeString, value) + } + if buo.mutation.OsnameCleared() { + _spec.ClearField(bouncer.FieldOsname, field.TypeString) + } + if value, ok := buo.mutation.Osversion(); ok { + _spec.SetField(bouncer.FieldOsversion, field.TypeString, value) + } + if buo.mutation.OsversionCleared() { + _spec.ClearField(bouncer.FieldOsversion, field.TypeString) + } + if value, ok := buo.mutation.Featureflags(); ok { + _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) + } + if buo.mutation.FeatureflagsCleared() { + _spec.ClearField(bouncer.FieldFeatureflags, field.TypeString) } _node = &Bouncer{config: buo.config} _spec.Assign = _node.assignValues @@ -794,5 +673,6 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e } return nil, err } + buo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go index 815b1df6d16..59686102ebe 100644 --- a/pkg/database/ent/client.go +++ b/pkg/database/ent/client.go @@ -7,20 +7,23 @@ import ( "errors" "fmt" "log" + "reflect" "github.com/crowdsecurity/crowdsec/pkg/database/ent/migrate" + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) // Client is the client that holds all ent builders. @@ -38,17 +41,19 @@ type Client struct { Decision *DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. Meta *MetaClient + // Metric is the client for interacting with the Metric builders. + Metric *MetricClient } // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} - cfg.options(opts...) - client := &Client{config: cfg} + client := &Client{config: newConfig(opts...)} client.init() return client } @@ -60,8 +65,66 @@ func (c *Client) init() { c.ConfigItem = NewConfigItemClient(c.config) c.Decision = NewDecisionClient(c.config) c.Event = NewEventClient(c.config) + c.Lock = NewLockClient(c.config) c.Machine = NewMachineClient(c.config) c.Meta = NewMetaClient(c.config) + c.Metric = NewMetricClient(c.config) +} + +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. + Option func(*config) +) + +// newConfig creates a new config for the client. +func newConfig(opts ...Option) config { + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.options(opts...) + return cfg +} + +// options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } } // Open opens a database/sql.DB specified by the driver name and @@ -80,11 +143,14 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) } } +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + // Tx returns a new transactional client. The provided context // is used until the transaction is committed or rolled back. func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, ErrTxStarted } tx, err := newTx(ctx, c.driver) if err != nil { @@ -100,8 +166,10 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -126,8 +194,10 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) ConfigItem: NewConfigItemClient(cfg), Decision: NewDecisionClient(cfg), Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), Machine: NewMachineClient(cfg), Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -156,13 +226,49 @@ func (c *Client) Close() error { // Use adds the mutation hooks to all the entity clients. // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { - c.Alert.Use(hooks...) - c.Bouncer.Use(hooks...) - c.ConfigItem.Use(hooks...) - c.Decision.Use(hooks...) - c.Event.Use(hooks...) - c.Machine.Use(hooks...) - c.Meta.Use(hooks...) + for _, n := range []interface{ Use(...Hook) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, c.Metric, + } { + n.Use(hooks...) + } +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + for _, n := range []interface{ Intercept(...Interceptor) }{ + c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, + c.Meta, c.Metric, + } { + n.Intercept(interceptors...) + } +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *AlertMutation: + return c.Alert.mutate(ctx, m) + case *BouncerMutation: + return c.Bouncer.mutate(ctx, m) + case *ConfigItemMutation: + return c.ConfigItem.mutate(ctx, m) + case *DecisionMutation: + return c.Decision.mutate(ctx, m) + case *EventMutation: + return c.Event.mutate(ctx, m) + case *LockMutation: + return c.Lock.mutate(ctx, m) + case *MachineMutation: + return c.Machine.mutate(ctx, m) + case *MetaMutation: + return c.Meta.mutate(ctx, m) + case *MetricMutation: + return c.Metric.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } } // AlertClient is a client for the Alert schema. @@ -181,6 +287,12 @@ func (c *AlertClient) Use(hooks ...Hook) { c.hooks.Alert = append(c.hooks.Alert, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `alert.Intercept(f(g(h())))`. +func (c *AlertClient) Intercept(interceptors ...Interceptor) { + c.inters.Alert = append(c.inters.Alert, interceptors...) +} + // Create returns a builder for creating a Alert entity. func (c *AlertClient) Create() *AlertCreate { mutation := newAlertMutation(c.config, OpCreate) @@ -192,6 +304,21 @@ func (c *AlertClient) CreateBulk(builders ...*AlertCreate) *AlertCreateBulk { return &AlertCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AlertClient) MapCreateBulk(slice any, setFunc func(*AlertCreate, int)) *AlertCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AlertCreateBulk{err: fmt.Errorf("calling to AlertClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AlertCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AlertCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Alert. func (c *AlertClient) Update() *AlertUpdate { mutation := newAlertMutation(c.config, OpUpdate) @@ -221,7 +348,7 @@ func (c *AlertClient) DeleteOne(a *Alert) *AlertDeleteOne { return c.DeleteOneID(a.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { builder := c.Delete().Where(alert.ID(id)) builder.mutation.id = &id @@ -233,6 +360,8 @@ func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne { func (c *AlertClient) Query() *AlertQuery { return &AlertQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAlert}, + inters: c.Interceptors(), } } @@ -252,8 +381,8 @@ func (c *AlertClient) GetX(ctx context.Context, id int) *Alert { // QueryOwner queries the owner edge of a Alert. func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { - query := &MachineQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MachineClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -268,8 +397,8 @@ func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery { // QueryDecisions queries the decisions edge of a Alert. func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { - query := &DecisionQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -284,8 +413,8 @@ func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery { // QueryEvents queries the events edge of a Alert. func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { - query := &EventQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&EventClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -300,8 +429,8 @@ func (c *AlertClient) QueryEvents(a *Alert) *EventQuery { // QueryMetas queries the metas edge of a Alert. func (c *AlertClient) QueryMetas(a *Alert) *MetaQuery { - query := &MetaQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MetaClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(alert.Table, alert.FieldID, id), @@ -319,6 +448,26 @@ func (c *AlertClient) Hooks() []Hook { return c.hooks.Alert } +// Interceptors returns the client interceptors. +func (c *AlertClient) Interceptors() []Interceptor { + return c.inters.Alert +} + +func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AlertCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AlertUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AlertUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AlertDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Alert mutation op: %q", m.Op()) + } +} + // BouncerClient is a client for the Bouncer schema. type BouncerClient struct { config @@ -335,6 +484,12 @@ func (c *BouncerClient) Use(hooks ...Hook) { c.hooks.Bouncer = append(c.hooks.Bouncer, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `bouncer.Intercept(f(g(h())))`. +func (c *BouncerClient) Intercept(interceptors ...Interceptor) { + c.inters.Bouncer = append(c.inters.Bouncer, interceptors...) +} + // Create returns a builder for creating a Bouncer entity. func (c *BouncerClient) Create() *BouncerCreate { mutation := newBouncerMutation(c.config, OpCreate) @@ -346,6 +501,21 @@ func (c *BouncerClient) CreateBulk(builders ...*BouncerCreate) *BouncerCreateBul return &BouncerCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *BouncerClient) MapCreateBulk(slice any, setFunc func(*BouncerCreate, int)) *BouncerCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &BouncerCreateBulk{err: fmt.Errorf("calling to BouncerClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*BouncerCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &BouncerCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Bouncer. func (c *BouncerClient) Update() *BouncerUpdate { mutation := newBouncerMutation(c.config, OpUpdate) @@ -375,7 +545,7 @@ func (c *BouncerClient) DeleteOne(b *Bouncer) *BouncerDeleteOne { return c.DeleteOneID(b.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { builder := c.Delete().Where(bouncer.ID(id)) builder.mutation.id = &id @@ -387,6 +557,8 @@ func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne { func (c *BouncerClient) Query() *BouncerQuery { return &BouncerQuery{ config: c.config, + ctx: &QueryContext{Type: TypeBouncer}, + inters: c.Interceptors(), } } @@ -409,6 +581,26 @@ func (c *BouncerClient) Hooks() []Hook { return c.hooks.Bouncer } +// Interceptors returns the client interceptors. +func (c *BouncerClient) Interceptors() []Interceptor { + return c.inters.Bouncer +} + +func (c *BouncerClient) mutate(ctx context.Context, m *BouncerMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&BouncerCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&BouncerUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&BouncerUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&BouncerDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Bouncer mutation op: %q", m.Op()) + } +} + // ConfigItemClient is a client for the ConfigItem schema. type ConfigItemClient struct { config @@ -425,6 +617,12 @@ func (c *ConfigItemClient) Use(hooks ...Hook) { c.hooks.ConfigItem = append(c.hooks.ConfigItem, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `configitem.Intercept(f(g(h())))`. +func (c *ConfigItemClient) Intercept(interceptors ...Interceptor) { + c.inters.ConfigItem = append(c.inters.ConfigItem, interceptors...) +} + // Create returns a builder for creating a ConfigItem entity. func (c *ConfigItemClient) Create() *ConfigItemCreate { mutation := newConfigItemMutation(c.config, OpCreate) @@ -436,6 +634,21 @@ func (c *ConfigItemClient) CreateBulk(builders ...*ConfigItemCreate) *ConfigItem return &ConfigItemCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ConfigItemClient) MapCreateBulk(slice any, setFunc func(*ConfigItemCreate, int)) *ConfigItemCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ConfigItemCreateBulk{err: fmt.Errorf("calling to ConfigItemClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ConfigItemCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ConfigItemCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for ConfigItem. func (c *ConfigItemClient) Update() *ConfigItemUpdate { mutation := newConfigItemMutation(c.config, OpUpdate) @@ -465,7 +678,7 @@ func (c *ConfigItemClient) DeleteOne(ci *ConfigItem) *ConfigItemDeleteOne { return c.DeleteOneID(ci.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { builder := c.Delete().Where(configitem.ID(id)) builder.mutation.id = &id @@ -477,6 +690,8 @@ func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne { func (c *ConfigItemClient) Query() *ConfigItemQuery { return &ConfigItemQuery{ config: c.config, + ctx: &QueryContext{Type: TypeConfigItem}, + inters: c.Interceptors(), } } @@ -499,6 +714,26 @@ func (c *ConfigItemClient) Hooks() []Hook { return c.hooks.ConfigItem } +// Interceptors returns the client interceptors. +func (c *ConfigItemClient) Interceptors() []Interceptor { + return c.inters.ConfigItem +} + +func (c *ConfigItemClient) mutate(ctx context.Context, m *ConfigItemMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ConfigItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ConfigItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ConfigItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ConfigItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ConfigItem mutation op: %q", m.Op()) + } +} + // DecisionClient is a client for the Decision schema. type DecisionClient struct { config @@ -515,6 +750,12 @@ func (c *DecisionClient) Use(hooks ...Hook) { c.hooks.Decision = append(c.hooks.Decision, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `decision.Intercept(f(g(h())))`. +func (c *DecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.Decision = append(c.inters.Decision, interceptors...) +} + // Create returns a builder for creating a Decision entity. func (c *DecisionClient) Create() *DecisionCreate { mutation := newDecisionMutation(c.config, OpCreate) @@ -526,6 +767,21 @@ func (c *DecisionClient) CreateBulk(builders ...*DecisionCreate) *DecisionCreate return &DecisionCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DecisionClient) MapCreateBulk(slice any, setFunc func(*DecisionCreate, int)) *DecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DecisionCreateBulk{err: fmt.Errorf("calling to DecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DecisionCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Decision. func (c *DecisionClient) Update() *DecisionUpdate { mutation := newDecisionMutation(c.config, OpUpdate) @@ -555,7 +811,7 @@ func (c *DecisionClient) DeleteOne(d *Decision) *DecisionDeleteOne { return c.DeleteOneID(d.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { builder := c.Delete().Where(decision.ID(id)) builder.mutation.id = &id @@ -567,6 +823,8 @@ func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne { func (c *DecisionClient) Query() *DecisionQuery { return &DecisionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDecision}, + inters: c.Interceptors(), } } @@ -586,8 +844,8 @@ func (c *DecisionClient) GetX(ctx context.Context, id int) *Decision { // QueryOwner queries the owner edge of a Decision. func (c *DecisionClient) QueryOwner(d *Decision) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(decision.Table, decision.FieldID, id), @@ -605,6 +863,26 @@ func (c *DecisionClient) Hooks() []Hook { return c.hooks.Decision } +// Interceptors returns the client interceptors. +func (c *DecisionClient) Interceptors() []Interceptor { + return c.inters.Decision +} + +func (c *DecisionClient) mutate(ctx context.Context, m *DecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Decision mutation op: %q", m.Op()) + } +} + // EventClient is a client for the Event schema. type EventClient struct { config @@ -621,6 +899,12 @@ func (c *EventClient) Use(hooks ...Hook) { c.hooks.Event = append(c.hooks.Event, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `event.Intercept(f(g(h())))`. +func (c *EventClient) Intercept(interceptors ...Interceptor) { + c.inters.Event = append(c.inters.Event, interceptors...) +} + // Create returns a builder for creating a Event entity. func (c *EventClient) Create() *EventCreate { mutation := newEventMutation(c.config, OpCreate) @@ -632,6 +916,21 @@ func (c *EventClient) CreateBulk(builders ...*EventCreate) *EventCreateBulk { return &EventCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *EventClient) MapCreateBulk(slice any, setFunc func(*EventCreate, int)) *EventCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &EventCreateBulk{err: fmt.Errorf("calling to EventClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*EventCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &EventCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Event. func (c *EventClient) Update() *EventUpdate { mutation := newEventMutation(c.config, OpUpdate) @@ -661,7 +960,7 @@ func (c *EventClient) DeleteOne(e *Event) *EventDeleteOne { return c.DeleteOneID(e.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { builder := c.Delete().Where(event.ID(id)) builder.mutation.id = &id @@ -673,6 +972,8 @@ func (c *EventClient) DeleteOneID(id int) *EventDeleteOne { func (c *EventClient) Query() *EventQuery { return &EventQuery{ config: c.config, + ctx: &QueryContext{Type: TypeEvent}, + inters: c.Interceptors(), } } @@ -692,8 +993,8 @@ func (c *EventClient) GetX(ctx context.Context, id int) *Event { // QueryOwner queries the owner edge of a Event. func (c *EventClient) QueryOwner(e *Event) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := e.ID step := sqlgraph.NewStep( sqlgraph.From(event.Table, event.FieldID, id), @@ -711,6 +1012,159 @@ func (c *EventClient) Hooks() []Hook { return c.hooks.Event } +// Interceptors returns the client interceptors. +func (c *EventClient) Interceptors() []Interceptor { + return c.inters.Event +} + +func (c *EventClient) mutate(ctx context.Context, m *EventMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&EventCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&EventUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&EventUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&EventDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Event mutation op: %q", m.Op()) + } +} + +// LockClient is a client for the Lock schema. +type LockClient struct { + config +} + +// NewLockClient returns a client for the Lock from the given config. +func NewLockClient(c config) *LockClient { + return &LockClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `lock.Hooks(f(g(h())))`. +func (c *LockClient) Use(hooks ...Hook) { + c.hooks.Lock = append(c.hooks.Lock, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `lock.Intercept(f(g(h())))`. +func (c *LockClient) Intercept(interceptors ...Interceptor) { + c.inters.Lock = append(c.inters.Lock, interceptors...) +} + +// Create returns a builder for creating a Lock entity. +func (c *LockClient) Create() *LockCreate { + mutation := newLockMutation(c.config, OpCreate) + return &LockCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Lock entities. +func (c *LockClient) CreateBulk(builders ...*LockCreate) *LockCreateBulk { + return &LockCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *LockClient) MapCreateBulk(slice any, setFunc func(*LockCreate, int)) *LockCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &LockCreateBulk{err: fmt.Errorf("calling to LockClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*LockCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &LockCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Lock. +func (c *LockClient) Update() *LockUpdate { + mutation := newLockMutation(c.config, OpUpdate) + return &LockUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *LockClient) UpdateOne(l *Lock) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLock(l)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *LockClient) UpdateOneID(id int) *LockUpdateOne { + mutation := newLockMutation(c.config, OpUpdateOne, withLockID(id)) + return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Lock. +func (c *LockClient) Delete() *LockDelete { + mutation := newLockMutation(c.config, OpDelete) + return &LockDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *LockClient) DeleteOne(l *Lock) *LockDeleteOne { + return c.DeleteOneID(l.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *LockClient) DeleteOneID(id int) *LockDeleteOne { + builder := c.Delete().Where(lock.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &LockDeleteOne{builder} +} + +// Query returns a query builder for Lock. +func (c *LockClient) Query() *LockQuery { + return &LockQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeLock}, + inters: c.Interceptors(), + } +} + +// Get returns a Lock entity by its id. +func (c *LockClient) Get(ctx context.Context, id int) (*Lock, error) { + return c.Query().Where(lock.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *LockClient) GetX(ctx context.Context, id int) *Lock { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *LockClient) Hooks() []Hook { + return c.hooks.Lock +} + +// Interceptors returns the client interceptors. +func (c *LockClient) Interceptors() []Interceptor { + return c.inters.Lock +} + +func (c *LockClient) mutate(ctx context.Context, m *LockMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&LockCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&LockUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&LockDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Lock mutation op: %q", m.Op()) + } +} + // MachineClient is a client for the Machine schema. type MachineClient struct { config @@ -727,6 +1181,12 @@ func (c *MachineClient) Use(hooks ...Hook) { c.hooks.Machine = append(c.hooks.Machine, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `machine.Intercept(f(g(h())))`. +func (c *MachineClient) Intercept(interceptors ...Interceptor) { + c.inters.Machine = append(c.inters.Machine, interceptors...) +} + // Create returns a builder for creating a Machine entity. func (c *MachineClient) Create() *MachineCreate { mutation := newMachineMutation(c.config, OpCreate) @@ -738,6 +1198,21 @@ func (c *MachineClient) CreateBulk(builders ...*MachineCreate) *MachineCreateBul return &MachineCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MachineClient) MapCreateBulk(slice any, setFunc func(*MachineCreate, int)) *MachineCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MachineCreateBulk{err: fmt.Errorf("calling to MachineClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MachineCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MachineCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Machine. func (c *MachineClient) Update() *MachineUpdate { mutation := newMachineMutation(c.config, OpUpdate) @@ -767,7 +1242,7 @@ func (c *MachineClient) DeleteOne(m *Machine) *MachineDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { builder := c.Delete().Where(machine.ID(id)) builder.mutation.id = &id @@ -779,6 +1254,8 @@ func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne { func (c *MachineClient) Query() *MachineQuery { return &MachineQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMachine}, + inters: c.Interceptors(), } } @@ -798,8 +1275,8 @@ func (c *MachineClient) GetX(ctx context.Context, id int) *Machine { // QueryAlerts queries the alerts edge of a Machine. func (c *MachineClient) QueryAlerts(m *Machine) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(machine.Table, machine.FieldID, id), @@ -817,6 +1294,26 @@ func (c *MachineClient) Hooks() []Hook { return c.hooks.Machine } +// Interceptors returns the client interceptors. +func (c *MachineClient) Interceptors() []Interceptor { + return c.inters.Machine +} + +func (c *MachineClient) mutate(ctx context.Context, m *MachineMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MachineCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MachineUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MachineUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MachineDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Machine mutation op: %q", m.Op()) + } +} + // MetaClient is a client for the Meta schema. type MetaClient struct { config @@ -833,6 +1330,12 @@ func (c *MetaClient) Use(hooks ...Hook) { c.hooks.Meta = append(c.hooks.Meta, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `meta.Intercept(f(g(h())))`. +func (c *MetaClient) Intercept(interceptors ...Interceptor) { + c.inters.Meta = append(c.inters.Meta, interceptors...) +} + // Create returns a builder for creating a Meta entity. func (c *MetaClient) Create() *MetaCreate { mutation := newMetaMutation(c.config, OpCreate) @@ -844,6 +1347,21 @@ func (c *MetaClient) CreateBulk(builders ...*MetaCreate) *MetaCreateBulk { return &MetaCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MetaClient) MapCreateBulk(slice any, setFunc func(*MetaCreate, int)) *MetaCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetaCreateBulk{err: fmt.Errorf("calling to MetaClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetaCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetaCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Meta. func (c *MetaClient) Update() *MetaUpdate { mutation := newMetaMutation(c.config, OpUpdate) @@ -873,7 +1391,7 @@ func (c *MetaClient) DeleteOne(m *Meta) *MetaDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { builder := c.Delete().Where(meta.ID(id)) builder.mutation.id = &id @@ -885,6 +1403,8 @@ func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne { func (c *MetaClient) Query() *MetaQuery { return &MetaQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMeta}, + inters: c.Interceptors(), } } @@ -904,8 +1424,8 @@ func (c *MetaClient) GetX(ctx context.Context, id int) *Meta { // QueryOwner queries the owner edge of a Meta. func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { - query := &AlertQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AlertClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(meta.Table, meta.FieldID, id), @@ -922,3 +1442,168 @@ func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery { func (c *MetaClient) Hooks() []Hook { return c.hooks.Meta } + +// Interceptors returns the client interceptors. +func (c *MetaClient) Interceptors() []Interceptor { + return c.inters.Meta +} + +func (c *MetaClient) mutate(ctx context.Context, m *MetaMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Meta mutation op: %q", m.Op()) + } +} + +// MetricClient is a client for the Metric schema. +type MetricClient struct { + config +} + +// NewMetricClient returns a client for the Metric from the given config. +func NewMetricClient(c config) *MetricClient { + return &MetricClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `metric.Hooks(f(g(h())))`. +func (c *MetricClient) Use(hooks ...Hook) { + c.hooks.Metric = append(c.hooks.Metric, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `metric.Intercept(f(g(h())))`. +func (c *MetricClient) Intercept(interceptors ...Interceptor) { + c.inters.Metric = append(c.inters.Metric, interceptors...) +} + +// Create returns a builder for creating a Metric entity. +func (c *MetricClient) Create() *MetricCreate { + mutation := newMetricMutation(c.config, OpCreate) + return &MetricCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Metric entities. +func (c *MetricClient) CreateBulk(builders ...*MetricCreate) *MetricCreateBulk { + return &MetricCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MetricClient) MapCreateBulk(slice any, setFunc func(*MetricCreate, int)) *MetricCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetricCreateBulk{err: fmt.Errorf("calling to MetricClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetricCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetricCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Metric. +func (c *MetricClient) Update() *MetricUpdate { + mutation := newMetricMutation(c.config, OpUpdate) + return &MetricUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *MetricClient) UpdateOne(m *Metric) *MetricUpdateOne { + mutation := newMetricMutation(c.config, OpUpdateOne, withMetric(m)) + return &MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *MetricClient) UpdateOneID(id int) *MetricUpdateOne { + mutation := newMetricMutation(c.config, OpUpdateOne, withMetricID(id)) + return &MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Metric. +func (c *MetricClient) Delete() *MetricDelete { + mutation := newMetricMutation(c.config, OpDelete) + return &MetricDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *MetricClient) DeleteOne(m *Metric) *MetricDeleteOne { + return c.DeleteOneID(m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *MetricClient) DeleteOneID(id int) *MetricDeleteOne { + builder := c.Delete().Where(metric.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &MetricDeleteOne{builder} +} + +// Query returns a query builder for Metric. +func (c *MetricClient) Query() *MetricQuery { + return &MetricQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeMetric}, + inters: c.Interceptors(), + } +} + +// Get returns a Metric entity by its id. +func (c *MetricClient) Get(ctx context.Context, id int) (*Metric, error) { + return c.Query().Where(metric.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *MetricClient) GetX(ctx context.Context, id int) *Metric { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *MetricClient) Hooks() []Hook { + return c.hooks.Metric +} + +// Interceptors returns the client interceptors. +func (c *MetricClient) Interceptors() []Interceptor { + return c.inters.Metric +} + +func (c *MetricClient) mutate(ctx context.Context, m *MetricMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetricCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetricUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetricUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetricDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Metric mutation op: %q", m.Op()) + } +} + +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, + Metric []ent.Hook + } + inters struct { + Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, + Metric []ent.Interceptor + } +) diff --git a/pkg/database/ent/config.go b/pkg/database/ent/config.go deleted file mode 100644 index 1a152809a32..00000000000 --- a/pkg/database/ent/config.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "entgo.io/ent" - "entgo.io/ent/dialect" -) - -// Option function to configure the client. -type Option func(*config) - -// Config is the configuration for the client and its builder. -type config struct { - // driver used for executing database requests. - driver dialect.Driver - // debug enable a debug logging. - debug bool - // log used for logging on debug mode. - log func(...any) - // hooks to execute on mutations. - hooks *hooks -} - -// hooks per client, for fast access. -type hooks struct { - Alert []ent.Hook - Bouncer []ent.Hook - ConfigItem []ent.Hook - Decision []ent.Hook - Event []ent.Hook - Machine []ent.Hook - Meta []ent.Hook -} - -// Options applies the options on the config object. -func (c *config) options(opts ...Option) { - for _, opt := range opts { - opt(c) - } - if c.debug { - c.driver = dialect.Debug(c.driver, c.log) - } -} - -// Debug enables debug logging on the ent.Driver. -func Debug() Option { - return func(c *config) { - c.debug = true - } -} - -// Log sets the logging function for debug mode. -func Log(fn func(...any)) Option { - return func(c *config) { - c.log = fn - } -} - -// Driver configures the client driver. -func Driver(driver dialect.Driver) Option { - return func(c *config) { - c.driver = driver - } -} diff --git a/pkg/database/ent/configitem.go b/pkg/database/ent/configitem.go index 615780dbacc..bdf23ef4948 100644 --- a/pkg/database/ent/configitem.go +++ b/pkg/database/ent/configitem.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) @@ -17,13 +18,14 @@ type ConfigItem struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at"` + CreatedAt time.Time `json:"created_at"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at"` + UpdatedAt time.Time `json:"updated_at"` // Name holds the value of the "name" field. Name string `json:"name"` // Value holds the value of the "value" field. - Value string `json:"value"` + Value string `json:"value"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -38,7 +40,7 @@ func (*ConfigItem) scanValues(columns []string) ([]any, error) { case configitem.FieldCreatedAt, configitem.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type ConfigItem", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -62,15 +64,13 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - ci.CreatedAt = new(time.Time) - *ci.CreatedAt = value.Time + ci.CreatedAt = value.Time } case configitem.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - ci.UpdatedAt = new(time.Time) - *ci.UpdatedAt = value.Time + ci.UpdatedAt = value.Time } case configitem.FieldName: if value, ok := values[i].(*sql.NullString); !ok { @@ -84,16 +84,24 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error { } else if value.Valid { ci.Value = value.String } + default: + ci.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the ConfigItem. +// This includes values selected through modifiers, order, etc. +func (ci *ConfigItem) GetValue(name string) (ent.Value, error) { + return ci.selectValues.Get(name) +} + // Update returns a builder for updating this ConfigItem. // Note that you need to call ConfigItem.Unwrap() before calling this method if this ConfigItem // was returned from a transaction, and the transaction was committed or rolled back. func (ci *ConfigItem) Update() *ConfigItemUpdateOne { - return (&ConfigItemClient{config: ci.config}).UpdateOne(ci) + return NewConfigItemClient(ci.config).UpdateOne(ci) } // Unwrap unwraps the ConfigItem entity that was returned from a transaction after it was closed, @@ -112,15 +120,11 @@ func (ci *ConfigItem) String() string { var builder strings.Builder builder.WriteString("ConfigItem(") builder.WriteString(fmt.Sprintf("id=%v, ", ci.ID)) - if v := ci.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(ci.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := ci.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(ci.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("name=") builder.WriteString(ci.Name) @@ -133,9 +137,3 @@ func (ci *ConfigItem) String() string { // ConfigItems is a parsable slice of ConfigItem. type ConfigItems []*ConfigItem - -func (ci ConfigItems) config(cfg config) { - for _i := range ci { - ci[_i].config = cfg - } -} diff --git a/pkg/database/ent/configitem/configitem.go b/pkg/database/ent/configitem/configitem.go index 80e93e4cc7e..611d81a3960 100644 --- a/pkg/database/ent/configitem/configitem.go +++ b/pkg/database/ent/configitem/configitem.go @@ -4,6 +4,8 @@ package configitem import ( "time" + + "entgo.io/ent/dialect/sql" ) const ( @@ -45,10 +47,36 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time ) + +// OrderOption defines the ordering options for the ConfigItem queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/pkg/database/ent/configitem/where.go b/pkg/database/ent/configitem/where.go index 6d06938a855..48ae792fd72 100644 --- a/pkg/database/ent/configitem/where.go +++ b/pkg/database/ent/configitem/where.go @@ -11,485 +11,290 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldUpdatedAt, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldName, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.ConfigItem { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.ConfigItem(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.ConfigItem(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. func And(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.ConfigItem(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.ConfigItem) predicate.ConfigItem { - return predicate.ConfigItem(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.ConfigItem(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/configitem_create.go b/pkg/database/ent/configitem_create.go index 736e6a50514..a2679927aee 100644 --- a/pkg/database/ent/configitem_create.go +++ b/pkg/database/ent/configitem_create.go @@ -67,50 +67,8 @@ func (cic *ConfigItemCreate) Mutation() *ConfigItemMutation { // Save creates the ConfigItem in the database. func (cic *ConfigItemCreate) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) cic.defaults() - if len(cic.hooks) == 0 { - if err = cic.check(); err != nil { - return nil, err - } - node, err = cic.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = cic.check(); err != nil { - return nil, err - } - cic.mutation = mutation - if node, err = cic.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(cic.hooks) - 1; i >= 0; i-- { - if cic.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cic.hooks[i](mut) - } - v, err := mut.Mutate(ctx, cic.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, cic.sqlSave, cic.mutation, cic.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -149,6 +107,12 @@ func (cic *ConfigItemCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (cic *ConfigItemCreate) check() error { + if _, ok := cic.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ConfigItem.created_at"`)} + } + if _, ok := cic.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ConfigItem.updated_at"`)} + } if _, ok := cic.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ConfigItem.name"`)} } @@ -159,6 +123,9 @@ func (cic *ConfigItemCreate) check() error { } func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { + if err := cic.check(); err != nil { + return nil, err + } _node, _spec := cic.createSpec() if err := sqlgraph.CreateNode(ctx, cic.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -168,50 +135,30 @@ func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + cic.mutation.id = &_node.ID + cic.mutation.done = true return _node, nil } func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { var ( _node = &ConfigItem{config: cic.config} - _spec = &sqlgraph.CreateSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) ) if value, ok := cic.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := cic.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := cic.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldName, field.TypeString, value) _node.Name = value } if value, ok := cic.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec @@ -220,11 +167,15 @@ func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) { // ConfigItemCreateBulk is the builder for creating many ConfigItem entities in bulk. type ConfigItemCreateBulk struct { config + err error builders []*ConfigItemCreate } // Save creates the ConfigItem entities in the database. func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, error) { + if cicb.err != nil { + return nil, cicb.err + } specs := make([]*sqlgraph.CreateSpec, len(cicb.builders)) nodes := make([]*ConfigItem, len(cicb.builders)) mutators := make([]Mutator, len(cicb.builders)) @@ -241,8 +192,8 @@ func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, erro return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, cicb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/configitem_delete.go b/pkg/database/ent/configitem_delete.go index 223fa9eefbf..a5dc811f60d 100644 --- a/pkg/database/ent/configitem_delete.go +++ b/pkg/database/ent/configitem_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (cid *ConfigItemDelete) Where(ps ...predicate.ConfigItem) *ConfigItemDelete // Exec executes the deletion query and returns how many vertices were deleted. func (cid *ConfigItemDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(cid.hooks) == 0 { - affected, err = cid.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - cid.mutation = mutation - affected, err = cid.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(cid.hooks) - 1; i >= 0; i-- { - if cid.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = cid.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, cid.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, cid.sqlExec, cid.mutation, cid.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (cid *ConfigItemDelete) ExecX(ctx context.Context) int { } func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := cid.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + cid.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type ConfigItemDeleteOne struct { cid *ConfigItemDelete } +// Where appends a list predicates to the ConfigItemDelete builder. +func (cido *ConfigItemDeleteOne) Where(ps ...predicate.ConfigItem) *ConfigItemDeleteOne { + cido.cid.mutation.Where(ps...) + return cido +} + // Exec executes the deletion query. func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { n, err := cido.cid.Exec(ctx) @@ -111,5 +82,7 @@ func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (cido *ConfigItemDeleteOne) ExecX(ctx context.Context) { - cido.cid.ExecX(ctx) + if err := cido.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/configitem_query.go b/pkg/database/ent/configitem_query.go index 6c9e6732a9b..f68b8953ddb 100644 --- a/pkg/database/ent/configitem_query.go +++ b/pkg/database/ent/configitem_query.go @@ -17,11 +17,9 @@ import ( // ConfigItemQuery is the builder for querying ConfigItem entities. type ConfigItemQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []configitem.OrderOption + inters []Interceptor predicates []predicate.ConfigItem // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,27 +32,27 @@ func (ciq *ConfigItemQuery) Where(ps ...predicate.ConfigItem) *ConfigItemQuery { return ciq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (ciq *ConfigItemQuery) Limit(limit int) *ConfigItemQuery { - ciq.limit = &limit + ciq.ctx.Limit = &limit return ciq } -// Offset adds an offset step to the query. +// Offset to start from. func (ciq *ConfigItemQuery) Offset(offset int) *ConfigItemQuery { - ciq.offset = &offset + ciq.ctx.Offset = &offset return ciq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (ciq *ConfigItemQuery) Unique(unique bool) *ConfigItemQuery { - ciq.unique = &unique + ciq.ctx.Unique = &unique return ciq } -// Order adds an order step to the query. -func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { +// Order specifies how the records should be ordered. +func (ciq *ConfigItemQuery) Order(o ...configitem.OrderOption) *ConfigItemQuery { ciq.order = append(ciq.order, o...) return ciq } @@ -62,7 +60,7 @@ func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery { // First returns the first ConfigItem entity from the query. // Returns a *NotFoundError when no ConfigItem was found. func (ciq *ConfigItemQuery) First(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(1).All(ctx) + nodes, err := ciq.Limit(1).All(setContextOp(ctx, ciq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (ciq *ConfigItemQuery) FirstX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no ConfigItem ID was found. func (ciq *ConfigItemQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(1).IDs(ctx); err != nil { + if ids, err = ciq.Limit(1).IDs(setContextOp(ctx, ciq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (ciq *ConfigItemQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one ConfigItem entity is found. // Returns a *NotFoundError when no ConfigItem entities are found. func (ciq *ConfigItemQuery) Only(ctx context.Context) (*ConfigItem, error) { - nodes, err := ciq.Limit(2).All(ctx) + nodes, err := ciq.Limit(2).All(setContextOp(ctx, ciq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (ciq *ConfigItemQuery) OnlyX(ctx context.Context) *ConfigItem { // Returns a *NotFoundError when no entities are found. func (ciq *ConfigItemQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = ciq.Limit(2).IDs(ctx); err != nil { + if ids, err = ciq.Limit(2).IDs(setContextOp(ctx, ciq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (ciq *ConfigItemQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of ConfigItems. func (ciq *ConfigItemQuery) All(ctx context.Context) ([]*ConfigItem, error) { + ctx = setContextOp(ctx, ciq.ctx, "All") if err := ciq.prepareQuery(ctx); err != nil { return nil, err } - return ciq.sqlAll(ctx) + qr := querierAll[[]*ConfigItem, *ConfigItemQuery]() + return withInterceptors[[]*ConfigItem](ctx, ciq, qr, ciq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (ciq *ConfigItemQuery) AllX(ctx context.Context) []*ConfigItem { } // IDs executes the query and returns a list of ConfigItem IDs. -func (ciq *ConfigItemQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { +func (ciq *ConfigItemQuery) IDs(ctx context.Context) (ids []int, err error) { + if ciq.ctx.Unique == nil && ciq.path != nil { + ciq.Unique(true) + } + ctx = setContextOp(ctx, ciq.ctx, "IDs") + if err = ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (ciq *ConfigItemQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (ciq *ConfigItemQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, ciq.ctx, "Count") if err := ciq.prepareQuery(ctx); err != nil { return 0, err } - return ciq.sqlCount(ctx) + return withInterceptors[int](ctx, ciq, querierCount[*ConfigItemQuery](), ciq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (ciq *ConfigItemQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (ciq *ConfigItemQuery) Exist(ctx context.Context) (bool, error) { - if err := ciq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, ciq.ctx, "Exist") + switch _, err := ciq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return ciq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { } return &ConfigItemQuery{ config: ciq.config, - limit: ciq.limit, - offset: ciq.offset, - order: append([]OrderFunc{}, ciq.order...), + ctx: ciq.ctx.Clone(), + order: append([]configitem.OrderOption{}, ciq.order...), + inters: append([]Interceptor{}, ciq.inters...), predicates: append([]predicate.ConfigItem{}, ciq.predicates...), // clone intermediate query. - sql: ciq.sql.Clone(), - path: ciq.path, - unique: ciq.unique, + sql: ciq.sql.Clone(), + path: ciq.path, } } @@ -262,16 +270,11 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemGroupBy { - grbuild := &ConfigItemGroupBy{config: ciq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := ciq.prepareQuery(ctx); err != nil { - return nil, err - } - return ciq.sqlQuery(ctx), nil - } + ciq.ctx.Fields = append([]string{field}, fields...) + grbuild := &ConfigItemGroupBy{build: ciq} + grbuild.flds = &ciq.ctx.Fields grbuild.label = configitem.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,15 +291,30 @@ func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemG // Select(configitem.FieldCreatedAt). // Scan(ctx, &v) func (ciq *ConfigItemQuery) Select(fields ...string) *ConfigItemSelect { - ciq.fields = append(ciq.fields, fields...) - selbuild := &ConfigItemSelect{ConfigItemQuery: ciq} - selbuild.label = configitem.Label - selbuild.flds, selbuild.scan = &ciq.fields, selbuild.Scan - return selbuild + ciq.ctx.Fields = append(ciq.ctx.Fields, fields...) + sbuild := &ConfigItemSelect{ConfigItemQuery: ciq} + sbuild.label = configitem.Label + sbuild.flds, sbuild.scan = &ciq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ConfigItemSelect configured with the given aggregations. +func (ciq *ConfigItemQuery) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + return ciq.Select().Aggregate(fns...) } func (ciq *ConfigItemQuery) prepareQuery(ctx context.Context) error { - for _, f := range ciq.fields { + for _, inter := range ciq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, ciq); err != nil { + return err + } + } + } + for _, f := range ciq.ctx.Fields { if !configitem.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -338,41 +356,22 @@ func (ciq *ConfigItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]* func (ciq *ConfigItemQuery) sqlCount(ctx context.Context) (int, error) { _spec := ciq.querySpec() - _spec.Node.Columns = ciq.fields - if len(ciq.fields) > 0 { - _spec.Unique = ciq.unique != nil && *ciq.unique + _spec.Node.Columns = ciq.ctx.Fields + if len(ciq.ctx.Fields) > 0 { + _spec.Unique = ciq.ctx.Unique != nil && *ciq.ctx.Unique } return sqlgraph.CountNodes(ctx, ciq.driver, _spec) } -func (ciq *ConfigItemQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := ciq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - From: ciq.sql, - Unique: true, - } - if unique := ciq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) + _spec.From = ciq.sql + if unique := ciq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if ciq.path != nil { + _spec.Unique = true } - if fields := ciq.fields; len(fields) > 0 { + if fields := ciq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, configitem.FieldID) for i := range fields { @@ -388,10 +387,10 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := ciq.order; len(ps) > 0 { @@ -407,7 +406,7 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec { func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(ciq.driver.Dialect()) t1 := builder.Table(configitem.Table) - columns := ciq.fields + columns := ciq.ctx.Fields if len(columns) == 0 { columns = configitem.Columns } @@ -416,7 +415,7 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = ciq.sql selector.Select(selector.Columns(columns...)...) } - if ciq.unique != nil && *ciq.unique { + if ciq.ctx.Unique != nil && *ciq.ctx.Unique { selector.Distinct() } for _, p := range ciq.predicates { @@ -425,12 +424,12 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range ciq.order { p(selector) } - if offset := ciq.offset; offset != nil { + if offset := ciq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := ciq.limit; limit != nil { + if limit := ciq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -438,13 +437,8 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector { // ConfigItemGroupBy is the group-by builder for ConfigItem entities. type ConfigItemGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *ConfigItemQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -453,74 +447,77 @@ func (cigb *ConfigItemGroupBy) Aggregate(fns ...AggregateFunc) *ConfigItemGroupB return cigb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (cigb *ConfigItemGroupBy) Scan(ctx context.Context, v any) error { - query, err := cigb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, cigb.build.ctx, "GroupBy") + if err := cigb.build.prepareQuery(ctx); err != nil { return err } - cigb.sql = query - return cigb.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemGroupBy](ctx, cigb.build, cigb, cigb.build.inters, v) } -func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range cigb.fields { - if !configitem.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(cigb.fns)) + for _, fn := range cigb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*cigb.flds)+len(cigb.fns)) + for _, f := range *cigb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := cigb.sqlQuery() + selector.GroupBy(selector.Columns(*cigb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := cigb.driver.Query(ctx, query, args, rows); err != nil { + if err := cigb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (cigb *ConfigItemGroupBy) sqlQuery() *sql.Selector { - selector := cigb.sql.Select() - aggregation := make([]string, 0, len(cigb.fns)) - for _, fn := range cigb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(cigb.fields)+len(cigb.fns)) - for _, f := range cigb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(cigb.fields...)...) -} - // ConfigItemSelect is the builder for selecting fields of ConfigItem entities. type ConfigItemSelect struct { *ConfigItemQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (cis *ConfigItemSelect) Aggregate(fns ...AggregateFunc) *ConfigItemSelect { + cis.fns = append(cis.fns, fns...) + return cis } // Scan applies the selector query and scans the result into the given value. func (cis *ConfigItemSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, cis.ctx, "Select") if err := cis.prepareQuery(ctx); err != nil { return err } - cis.sql = cis.ConfigItemQuery.sqlQuery(ctx) - return cis.sqlScan(ctx, v) + return scanWithInterceptors[*ConfigItemQuery, *ConfigItemSelect](ctx, cis.ConfigItemQuery, cis, cis.inters, v) } -func (cis *ConfigItemSelect) sqlScan(ctx context.Context, v any) error { +func (cis *ConfigItemSelect) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(cis.fns)) + for _, fn := range cis.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*cis.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := cis.sql.Query() + query, args := selector.Query() if err := cis.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/configitem_update.go b/pkg/database/ent/configitem_update.go index e591347a0c3..82309459e76 100644 --- a/pkg/database/ent/configitem_update.go +++ b/pkg/database/ent/configitem_update.go @@ -28,42 +28,26 @@ func (ciu *ConfigItemUpdate) Where(ps ...predicate.ConfigItem) *ConfigItemUpdate return ciu } -// SetCreatedAt sets the "created_at" field. -func (ciu *ConfigItemUpdate) SetCreatedAt(t time.Time) *ConfigItemUpdate { - ciu.mutation.SetCreatedAt(t) - return ciu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (ciu *ConfigItemUpdate) ClearCreatedAt() *ConfigItemUpdate { - ciu.mutation.ClearCreatedAt() - return ciu -} - // SetUpdatedAt sets the "updated_at" field. func (ciu *ConfigItemUpdate) SetUpdatedAt(t time.Time) *ConfigItemUpdate { ciu.mutation.SetUpdatedAt(t) return ciu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (ciu *ConfigItemUpdate) ClearUpdatedAt() *ConfigItemUpdate { - ciu.mutation.ClearUpdatedAt() - return ciu -} - -// SetName sets the "name" field. -func (ciu *ConfigItemUpdate) SetName(s string) *ConfigItemUpdate { - ciu.mutation.SetName(s) - return ciu -} - // SetValue sets the "value" field. func (ciu *ConfigItemUpdate) SetValue(s string) *ConfigItemUpdate { ciu.mutation.SetValue(s) return ciu } +// SetNillableValue sets the "value" field if the given value is not nil. +func (ciu *ConfigItemUpdate) SetNillableValue(s *string) *ConfigItemUpdate { + if s != nil { + ciu.SetValue(*s) + } + return ciu +} + // Mutation returns the ConfigItemMutation object of the builder. func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation { return ciu.mutation @@ -71,35 +55,8 @@ func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (ciu *ConfigItemUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) ciu.defaults() - if len(ciu.hooks) == 0 { - affected, err = ciu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciu.mutation = mutation - affected, err = ciu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(ciu.hooks) - 1; i >= 0; i-- { - if ciu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ciu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ciu.sqlSave, ciu.mutation, ciu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -126,27 +83,14 @@ func (ciu *ConfigItemUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (ciu *ConfigItemUpdate) defaults() { - if _, ok := ciu.mutation.CreatedAt(); !ok && !ciu.mutation.CreatedAtCleared() { - v := configitem.UpdateDefaultCreatedAt() - ciu.mutation.SetCreatedAt(v) - } - if _, ok := ciu.mutation.UpdatedAt(); !ok && !ciu.mutation.UpdatedAtCleared() { + if _, ok := ciu.mutation.UpdatedAt(); !ok { v := configitem.UpdateDefaultUpdatedAt() ciu.mutation.SetUpdatedAt(v) } } func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) if ps := ciu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -154,45 +98,11 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := ciu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - } - if ciu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) - } if value, ok := ciu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - } - if ciu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) - } - if value, ok := ciu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if value, ok := ciu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, ciu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -202,6 +112,7 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + ciu.mutation.done = true return n, nil } @@ -213,47 +124,37 @@ type ConfigItemUpdateOne struct { mutation *ConfigItemMutation } -// SetCreatedAt sets the "created_at" field. -func (ciuo *ConfigItemUpdateOne) SetCreatedAt(t time.Time) *ConfigItemUpdateOne { - ciuo.mutation.SetCreatedAt(t) - return ciuo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (ciuo *ConfigItemUpdateOne) ClearCreatedAt() *ConfigItemUpdateOne { - ciuo.mutation.ClearCreatedAt() - return ciuo -} - // SetUpdatedAt sets the "updated_at" field. func (ciuo *ConfigItemUpdateOne) SetUpdatedAt(t time.Time) *ConfigItemUpdateOne { ciuo.mutation.SetUpdatedAt(t) return ciuo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (ciuo *ConfigItemUpdateOne) ClearUpdatedAt() *ConfigItemUpdateOne { - ciuo.mutation.ClearUpdatedAt() - return ciuo -} - -// SetName sets the "name" field. -func (ciuo *ConfigItemUpdateOne) SetName(s string) *ConfigItemUpdateOne { - ciuo.mutation.SetName(s) - return ciuo -} - // SetValue sets the "value" field. func (ciuo *ConfigItemUpdateOne) SetValue(s string) *ConfigItemUpdateOne { ciuo.mutation.SetValue(s) return ciuo } +// SetNillableValue sets the "value" field if the given value is not nil. +func (ciuo *ConfigItemUpdateOne) SetNillableValue(s *string) *ConfigItemUpdateOne { + if s != nil { + ciuo.SetValue(*s) + } + return ciuo +} + // Mutation returns the ConfigItemMutation object of the builder. func (ciuo *ConfigItemUpdateOne) Mutation() *ConfigItemMutation { return ciuo.mutation } +// Where appends a list predicates to the ConfigItemUpdate builder. +func (ciuo *ConfigItemUpdateOne) Where(ps ...predicate.ConfigItem) *ConfigItemUpdateOne { + ciuo.mutation.Where(ps...) + return ciuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigItemUpdateOne { @@ -263,41 +164,8 @@ func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigI // Save executes the query and returns the updated ConfigItem entity. func (ciuo *ConfigItemUpdateOne) Save(ctx context.Context) (*ConfigItem, error) { - var ( - err error - node *ConfigItem - ) ciuo.defaults() - if len(ciuo.hooks) == 0 { - node, err = ciuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ciuo.mutation = mutation - node, err = ciuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(ciuo.hooks) - 1; i >= 0; i-- { - if ciuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ciuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ciuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*ConfigItem) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ciuo.sqlSave, ciuo.mutation, ciuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -324,27 +192,14 @@ func (ciuo *ConfigItemUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (ciuo *ConfigItemUpdateOne) defaults() { - if _, ok := ciuo.mutation.CreatedAt(); !ok && !ciuo.mutation.CreatedAtCleared() { - v := configitem.UpdateDefaultCreatedAt() - ciuo.mutation.SetCreatedAt(v) - } - if _, ok := ciuo.mutation.UpdatedAt(); !ok && !ciuo.mutation.UpdatedAtCleared() { + if _, ok := ciuo.mutation.UpdatedAt(); !ok { v := configitem.UpdateDefaultUpdatedAt() ciuo.mutation.SetUpdatedAt(v) } } func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: configitem.Table, - Columns: configitem.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: configitem.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt)) id, ok := ciuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ConfigItem.id" for update`)} @@ -369,45 +224,11 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } } } - if value, ok := ciuo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldCreatedAt, - }) - } - if ciuo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldCreatedAt, - }) - } if value, ok := ciuo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: configitem.FieldUpdatedAt, - }) - } - if ciuo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: configitem.FieldUpdatedAt, - }) - } - if value, ok := ciuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldName, - }) + _spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value) } if value, ok := ciuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: configitem.FieldValue, - }) + _spec.SetField(configitem.FieldValue, field.TypeString, value) } _node = &ConfigItem{config: ciuo.config} _spec.Assign = _node.assignValues @@ -420,5 +241,6 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem } return nil, err } + ciuo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/context.go b/pkg/database/ent/context.go deleted file mode 100644 index 7811bfa2349..00000000000 --- a/pkg/database/ent/context.go +++ /dev/null @@ -1,33 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package ent - -import ( - "context" -) - -type clientCtxKey struct{} - -// FromContext returns a Client stored inside a context, or nil if there isn't one. -func FromContext(ctx context.Context) *Client { - c, _ := ctx.Value(clientCtxKey{}).(*Client) - return c -} - -// NewContext returns a new context with the given Client attached. -func NewContext(parent context.Context, c *Client) context.Context { - return context.WithValue(parent, clientCtxKey{}, c) -} - -type txCtxKey struct{} - -// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. -func TxFromContext(ctx context.Context) *Tx { - tx, _ := ctx.Value(txCtxKey{}).(*Tx) - return tx -} - -// NewTxContext returns a new context with the given Tx attached. -func NewTxContext(parent context.Context, tx *Tx) context.Context { - return context.WithValue(parent, txCtxKey{}, tx) -} diff --git a/pkg/database/ent/decision.go b/pkg/database/ent/decision.go index c969e576724..4a6dc728509 100644 --- a/pkg/database/ent/decision.go +++ b/pkg/database/ent/decision.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -18,9 +19,9 @@ type Decision struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Until holds the value of the "until" field. Until *time.Time `json:"until,omitempty"` // Scenario holds the value of the "scenario" field. @@ -51,7 +52,8 @@ type Decision struct { AlertDecisions int `json:"alert_decisions,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the DecisionQuery when eager-loading is set. - Edges DecisionEdges `json:"edges"` + Edges DecisionEdges `json:"edges"` + selectValues sql.SelectValues } // DecisionEdges holds the relations/edges for other nodes in the graph. @@ -66,12 +68,10 @@ type DecisionEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e DecisionEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. - return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -90,7 +90,7 @@ func (*Decision) scanValues(columns []string) ([]any, error) { case decision.FieldCreatedAt, decision.FieldUpdatedAt, decision.FieldUntil: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Decision", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -114,15 +114,13 @@ func (d *Decision) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - d.CreatedAt = new(time.Time) - *d.CreatedAt = value.Time + d.CreatedAt = value.Time } case decision.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - d.UpdatedAt = new(time.Time) - *d.UpdatedAt = value.Time + d.UpdatedAt = value.Time } case decision.FieldUntil: if value, ok := values[i].(*sql.NullTime); !ok { @@ -209,21 +207,29 @@ func (d *Decision) assignValues(columns []string, values []any) error { } else if value.Valid { d.AlertDecisions = int(value.Int64) } + default: + d.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Decision. +// This includes values selected through modifiers, order, etc. +func (d *Decision) GetValue(name string) (ent.Value, error) { + return d.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Decision entity. func (d *Decision) QueryOwner() *AlertQuery { - return (&DecisionClient{config: d.config}).QueryOwner(d) + return NewDecisionClient(d.config).QueryOwner(d) } // Update returns a builder for updating this Decision. // Note that you need to call Decision.Unwrap() before calling this method if this Decision // was returned from a transaction, and the transaction was committed or rolled back. func (d *Decision) Update() *DecisionUpdateOne { - return (&DecisionClient{config: d.config}).UpdateOne(d) + return NewDecisionClient(d.config).UpdateOne(d) } // Unwrap unwraps the Decision entity that was returned from a transaction after it was closed, @@ -242,15 +248,11 @@ func (d *Decision) String() string { var builder strings.Builder builder.WriteString("Decision(") builder.WriteString(fmt.Sprintf("id=%v, ", d.ID)) - if v := d.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(d.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := d.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(d.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") if v := d.Until; v != nil { builder.WriteString("until=") @@ -301,9 +303,3 @@ func (d *Decision) String() string { // Decisions is a parsable slice of Decision. type Decisions []*Decision - -func (d Decisions) config(cfg config) { - for _i := range d { - d[_i].config = cfg - } -} diff --git a/pkg/database/ent/decision/decision.go b/pkg/database/ent/decision/decision.go index a0012d940a8..38c9721db48 100644 --- a/pkg/database/ent/decision/decision.go +++ b/pkg/database/ent/decision/decision.go @@ -4,6 +4,9 @@ package decision import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -90,8 +93,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -99,3 +100,105 @@ var ( // DefaultSimulated holds the default value on creation for the "simulated" field. DefaultSimulated bool ) + +// OrderOption defines the ordering options for the Decision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUntil orders the results by the until field. +func ByUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUntil, opts...).ToFunc() +} + +// ByScenario orders the results by the scenario field. +func ByScenario(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenario, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByStartIP orders the results by the start_ip field. +func ByStartIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartIP, opts...).ToFunc() +} + +// ByEndIP orders the results by the end_ip field. +func ByEndIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndIP, opts...).ToFunc() +} + +// ByStartSuffix orders the results by the start_suffix field. +func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartSuffix, opts...).ToFunc() +} + +// ByEndSuffix orders the results by the end_suffix field. +func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndSuffix, opts...).ToFunc() +} + +// ByIPSize orders the results by the ip_size field. +func ByIPSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPSize, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByOrigin orders the results by the origin field. +func ByOrigin(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOrigin, opts...).ToFunc() +} + +// BySimulated orders the results by the simulated field. +func BySimulated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulated, opts...).ToFunc() +} + +// ByUUID orders the results by the uuid field. +func ByUUID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUUID, opts...).ToFunc() +} + +// ByAlertDecisions orders the results by the alert_decisions field. +func ByAlertDecisions(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertDecisions, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/decision/where.go b/pkg/database/ent/decision/where.go index 18716a4a7c1..99a1889e63e 100644 --- a/pkg/database/ent/decision/where.go +++ b/pkg/database/ent/decision/where.go @@ -12,1481 +12,947 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Decision(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // Until applies equality check predicate on the "until" field. It's identical to UntilEQ. func Until(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // Scenario applies equality check predicate on the "scenario" field. It's identical to ScenarioEQ. func Scenario(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // StartIP applies equality check predicate on the "start_ip" field. It's identical to StartIPEQ. func StartIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // EndIP applies equality check predicate on the "end_ip" field. It's identical to EndIPEQ. func EndIP(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // StartSuffix applies equality check predicate on the "start_suffix" field. It's identical to StartSuffixEQ. func StartSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // EndSuffix applies equality check predicate on the "end_suffix" field. It's identical to EndSuffixEQ. func EndSuffix(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // IPSize applies equality check predicate on the "ip_size" field. It's identical to IPSizeEQ. func IPSize(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. func Scope(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // Origin applies equality check predicate on the "origin" field. It's identical to OriginEQ. func Origin(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // Simulated applies equality check predicate on the "simulated" field. It's identical to SimulatedEQ. func Simulated(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // UUID applies equality check predicate on the "uuid" field. It's identical to UUIDEQ. func UUID(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // AlertDecisions applies equality check predicate on the "alert_decisions" field. It's identical to AlertDecisionsEQ. func AlertDecisions(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Decision(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Decision(sql.FieldLTE(FieldUpdatedAt, v)) } // UntilEQ applies the EQ predicate on the "until" field. func UntilEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUntil, v)) } // UntilNEQ applies the NEQ predicate on the "until" field. func UntilNEQ(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUntil, v)) } // UntilIn applies the In predicate on the "until" field. func UntilIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUntil, vs...)) } // UntilNotIn applies the NotIn predicate on the "until" field. func UntilNotIn(vs ...time.Time) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUntil), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUntil, vs...)) } // UntilGT applies the GT predicate on the "until" field. func UntilGT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUntil, v)) } // UntilGTE applies the GTE predicate on the "until" field. func UntilGTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUntil, v)) } // UntilLT applies the LT predicate on the "until" field. func UntilLT(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUntil, v)) } // UntilLTE applies the LTE predicate on the "until" field. func UntilLTE(v time.Time) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUntil), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUntil, v)) } // UntilIsNil applies the IsNil predicate on the "until" field. func UntilIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUntil)) } // UntilNotNil applies the NotNil predicate on the "until" field. func UntilNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUntil))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUntil)) } // ScenarioEQ applies the EQ predicate on the "scenario" field. func ScenarioEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScenario, v)) } // ScenarioNEQ applies the NEQ predicate on the "scenario" field. func ScenarioNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScenario, v)) } // ScenarioIn applies the In predicate on the "scenario" field. func ScenarioIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScenario, vs...)) } // ScenarioNotIn applies the NotIn predicate on the "scenario" field. func ScenarioNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenario), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScenario, vs...)) } // ScenarioGT applies the GT predicate on the "scenario" field. func ScenarioGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScenario, v)) } // ScenarioGTE applies the GTE predicate on the "scenario" field. func ScenarioGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScenario, v)) } // ScenarioLT applies the LT predicate on the "scenario" field. func ScenarioLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScenario, v)) } // ScenarioLTE applies the LTE predicate on the "scenario" field. func ScenarioLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScenario, v)) } // ScenarioContains applies the Contains predicate on the "scenario" field. func ScenarioContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScenario, v)) } // ScenarioHasPrefix applies the HasPrefix predicate on the "scenario" field. func ScenarioHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScenario, v)) } // ScenarioHasSuffix applies the HasSuffix predicate on the "scenario" field. func ScenarioHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScenario, v)) } // ScenarioEqualFold applies the EqualFold predicate on the "scenario" field. func ScenarioEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScenario, v)) } // ScenarioContainsFold applies the ContainsFold predicate on the "scenario" field. func ScenarioContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenario), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScenario, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. func TypeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldType, v)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldType, v)) } // StartIPEQ applies the EQ predicate on the "start_ip" field. func StartIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartIP, v)) } // StartIPNEQ applies the NEQ predicate on the "start_ip" field. func StartIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartIP, v)) } // StartIPIn applies the In predicate on the "start_ip" field. func StartIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartIP, vs...)) } // StartIPNotIn applies the NotIn predicate on the "start_ip" field. func StartIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartIP, vs...)) } // StartIPGT applies the GT predicate on the "start_ip" field. func StartIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartIP, v)) } // StartIPGTE applies the GTE predicate on the "start_ip" field. func StartIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartIP, v)) } // StartIPLT applies the LT predicate on the "start_ip" field. func StartIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartIP, v)) } // StartIPLTE applies the LTE predicate on the "start_ip" field. func StartIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartIP, v)) } // StartIPIsNil applies the IsNil predicate on the "start_ip" field. func StartIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartIP)) } // StartIPNotNil applies the NotNil predicate on the "start_ip" field. func StartIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartIP)) } // EndIPEQ applies the EQ predicate on the "end_ip" field. func EndIPEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndIP, v)) } // EndIPNEQ applies the NEQ predicate on the "end_ip" field. func EndIPNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndIP, v)) } // EndIPIn applies the In predicate on the "end_ip" field. func EndIPIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndIP, vs...)) } // EndIPNotIn applies the NotIn predicate on the "end_ip" field. func EndIPNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndIP), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndIP, vs...)) } // EndIPGT applies the GT predicate on the "end_ip" field. func EndIPGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndIP, v)) } // EndIPGTE applies the GTE predicate on the "end_ip" field. func EndIPGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndIP, v)) } // EndIPLT applies the LT predicate on the "end_ip" field. func EndIPLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndIP, v)) } // EndIPLTE applies the LTE predicate on the "end_ip" field. func EndIPLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndIP), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndIP, v)) } // EndIPIsNil applies the IsNil predicate on the "end_ip" field. func EndIPIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndIP)) } // EndIPNotNil applies the NotNil predicate on the "end_ip" field. func EndIPNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndIP))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndIP)) } // StartSuffixEQ applies the EQ predicate on the "start_suffix" field. func StartSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldStartSuffix, v)) } // StartSuffixNEQ applies the NEQ predicate on the "start_suffix" field. func StartSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldStartSuffix, v)) } // StartSuffixIn applies the In predicate on the "start_suffix" field. func StartSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldStartSuffix, vs...)) } // StartSuffixNotIn applies the NotIn predicate on the "start_suffix" field. func StartSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStartSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldStartSuffix, vs...)) } // StartSuffixGT applies the GT predicate on the "start_suffix" field. func StartSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldStartSuffix, v)) } // StartSuffixGTE applies the GTE predicate on the "start_suffix" field. func StartSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldStartSuffix, v)) } // StartSuffixLT applies the LT predicate on the "start_suffix" field. func StartSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldStartSuffix, v)) } // StartSuffixLTE applies the LTE predicate on the "start_suffix" field. func StartSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStartSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldStartSuffix, v)) } // StartSuffixIsNil applies the IsNil predicate on the "start_suffix" field. func StartSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldStartSuffix)) } // StartSuffixNotNil applies the NotNil predicate on the "start_suffix" field. func StartSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStartSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldStartSuffix)) } // EndSuffixEQ applies the EQ predicate on the "end_suffix" field. func EndSuffixEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldEndSuffix, v)) } // EndSuffixNEQ applies the NEQ predicate on the "end_suffix" field. func EndSuffixNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldEndSuffix, v)) } // EndSuffixIn applies the In predicate on the "end_suffix" field. func EndSuffixIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldEndSuffix, vs...)) } // EndSuffixNotIn applies the NotIn predicate on the "end_suffix" field. func EndSuffixNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEndSuffix), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldEndSuffix, vs...)) } // EndSuffixGT applies the GT predicate on the "end_suffix" field. func EndSuffixGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGT(FieldEndSuffix, v)) } // EndSuffixGTE applies the GTE predicate on the "end_suffix" field. func EndSuffixGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldEndSuffix, v)) } // EndSuffixLT applies the LT predicate on the "end_suffix" field. func EndSuffixLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLT(FieldEndSuffix, v)) } // EndSuffixLTE applies the LTE predicate on the "end_suffix" field. func EndSuffixLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEndSuffix), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldEndSuffix, v)) } // EndSuffixIsNil applies the IsNil predicate on the "end_suffix" field. func EndSuffixIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldIsNull(FieldEndSuffix)) } // EndSuffixNotNil applies the NotNil predicate on the "end_suffix" field. func EndSuffixNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldEndSuffix))) - }) + return predicate.Decision(sql.FieldNotNull(FieldEndSuffix)) } // IPSizeEQ applies the EQ predicate on the "ip_size" field. func IPSizeEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldIPSize, v)) } // IPSizeNEQ applies the NEQ predicate on the "ip_size" field. func IPSizeNEQ(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldIPSize, v)) } // IPSizeIn applies the In predicate on the "ip_size" field. func IPSizeIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldIPSize, vs...)) } // IPSizeNotIn applies the NotIn predicate on the "ip_size" field. func IPSizeNotIn(vs ...int64) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIPSize), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldIPSize, vs...)) } // IPSizeGT applies the GT predicate on the "ip_size" field. func IPSizeGT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGT(FieldIPSize, v)) } // IPSizeGTE applies the GTE predicate on the "ip_size" field. func IPSizeGTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldIPSize, v)) } // IPSizeLT applies the LT predicate on the "ip_size" field. func IPSizeLT(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLT(FieldIPSize, v)) } // IPSizeLTE applies the LTE predicate on the "ip_size" field. func IPSizeLTE(v int64) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIPSize), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldIPSize, v)) } // IPSizeIsNil applies the IsNil predicate on the "ip_size" field. func IPSizeIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldIsNull(FieldIPSize)) } // IPSizeNotNil applies the NotNil predicate on the "ip_size" field. func IPSizeNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldIPSize))) - }) + return predicate.Decision(sql.FieldNotNull(FieldIPSize)) } // ScopeEQ applies the EQ predicate on the "scope" field. func ScopeEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldScope, v)) } // ScopeNEQ applies the NEQ predicate on the "scope" field. func ScopeNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldScope, v)) } // ScopeIn applies the In predicate on the "scope" field. func ScopeIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldScope, vs...)) } // ScopeNotIn applies the NotIn predicate on the "scope" field. func ScopeNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScope), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldScope, vs...)) } // ScopeGT applies the GT predicate on the "scope" field. func ScopeGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGT(FieldScope, v)) } // ScopeGTE applies the GTE predicate on the "scope" field. func ScopeGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldScope, v)) } // ScopeLT applies the LT predicate on the "scope" field. func ScopeLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLT(FieldScope, v)) } // ScopeLTE applies the LTE predicate on the "scope" field. func ScopeLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldScope, v)) } // ScopeContains applies the Contains predicate on the "scope" field. func ScopeContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContains(FieldScope, v)) } // ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. func ScopeHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldScope, v)) } // ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. func ScopeHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldScope, v)) } // ScopeEqualFold applies the EqualFold predicate on the "scope" field. func ScopeEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldScope, v)) } // ScopeContainsFold applies the ContainsFold predicate on the "scope" field. func ScopeContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScope), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldScope, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldValue, v)) } // OriginEQ applies the EQ predicate on the "origin" field. func OriginEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldOrigin, v)) } // OriginNEQ applies the NEQ predicate on the "origin" field. func OriginNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldOrigin, v)) } // OriginIn applies the In predicate on the "origin" field. func OriginIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldOrigin, vs...)) } // OriginNotIn applies the NotIn predicate on the "origin" field. func OriginNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldOrigin), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldOrigin, vs...)) } // OriginGT applies the GT predicate on the "origin" field. func OriginGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGT(FieldOrigin, v)) } // OriginGTE applies the GTE predicate on the "origin" field. func OriginGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldOrigin, v)) } // OriginLT applies the LT predicate on the "origin" field. func OriginLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLT(FieldOrigin, v)) } // OriginLTE applies the LTE predicate on the "origin" field. func OriginLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldOrigin, v)) } // OriginContains applies the Contains predicate on the "origin" field. func OriginContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContains(FieldOrigin, v)) } // OriginHasPrefix applies the HasPrefix predicate on the "origin" field. func OriginHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldOrigin, v)) } // OriginHasSuffix applies the HasSuffix predicate on the "origin" field. func OriginHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldOrigin, v)) } // OriginEqualFold applies the EqualFold predicate on the "origin" field. func OriginEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldOrigin, v)) } // OriginContainsFold applies the ContainsFold predicate on the "origin" field. func OriginContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldOrigin), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldOrigin, v)) } // SimulatedEQ applies the EQ predicate on the "simulated" field. func SimulatedEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldSimulated, v)) } // SimulatedNEQ applies the NEQ predicate on the "simulated" field. func SimulatedNEQ(v bool) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSimulated), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldSimulated, v)) } // UUIDEQ applies the EQ predicate on the "uuid" field. func UUIDEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldUUID, v)) } // UUIDNEQ applies the NEQ predicate on the "uuid" field. func UUIDNEQ(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldUUID, v)) } // UUIDIn applies the In predicate on the "uuid" field. func UUIDIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldUUID, vs...)) } // UUIDNotIn applies the NotIn predicate on the "uuid" field. func UUIDNotIn(vs ...string) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUUID), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldUUID, vs...)) } // UUIDGT applies the GT predicate on the "uuid" field. func UUIDGT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGT(FieldUUID, v)) } // UUIDGTE applies the GTE predicate on the "uuid" field. func UUIDGTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldGTE(FieldUUID, v)) } // UUIDLT applies the LT predicate on the "uuid" field. func UUIDLT(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLT(FieldUUID, v)) } // UUIDLTE applies the LTE predicate on the "uuid" field. func UUIDLTE(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldLTE(FieldUUID, v)) } // UUIDContains applies the Contains predicate on the "uuid" field. func UUIDContains(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContains(FieldUUID, v)) } // UUIDHasPrefix applies the HasPrefix predicate on the "uuid" field. func UUIDHasPrefix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasPrefix(FieldUUID, v)) } // UUIDHasSuffix applies the HasSuffix predicate on the "uuid" field. func UUIDHasSuffix(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldHasSuffix(FieldUUID, v)) } // UUIDIsNil applies the IsNil predicate on the "uuid" field. func UUIDIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldIsNull(FieldUUID)) } // UUIDNotNil applies the NotNil predicate on the "uuid" field. func UUIDNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUUID))) - }) + return predicate.Decision(sql.FieldNotNull(FieldUUID)) } // UUIDEqualFold applies the EqualFold predicate on the "uuid" field. func UUIDEqualFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldEqualFold(FieldUUID, v)) } // UUIDContainsFold applies the ContainsFold predicate on the "uuid" field. func UUIDContainsFold(v string) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldUUID), v)) - }) + return predicate.Decision(sql.FieldContainsFold(FieldUUID, v)) } // AlertDecisionsEQ applies the EQ predicate on the "alert_decisions" field. func AlertDecisionsEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldEQ(FieldAlertDecisions, v)) } // AlertDecisionsNEQ applies the NEQ predicate on the "alert_decisions" field. func AlertDecisionsNEQ(v int) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertDecisions), v)) - }) + return predicate.Decision(sql.FieldNEQ(FieldAlertDecisions, v)) } // AlertDecisionsIn applies the In predicate on the "alert_decisions" field. func AlertDecisionsIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldIn(FieldAlertDecisions, vs...)) } // AlertDecisionsNotIn applies the NotIn predicate on the "alert_decisions" field. func AlertDecisionsNotIn(vs ...int) predicate.Decision { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertDecisions), v...)) - }) + return predicate.Decision(sql.FieldNotIn(FieldAlertDecisions, vs...)) } // AlertDecisionsIsNil applies the IsNil predicate on the "alert_decisions" field. func AlertDecisionsIsNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldIsNull(FieldAlertDecisions)) } // AlertDecisionsNotNil applies the NotNil predicate on the "alert_decisions" field. func AlertDecisionsNotNil() predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertDecisions))) - }) + return predicate.Decision(sql.FieldNotNull(FieldAlertDecisions)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -1494,7 +960,6 @@ func HasOwner() predicate.Decision { return predicate.Decision(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1504,11 +969,7 @@ func HasOwner() predicate.Decision { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { return predicate.Decision(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1519,32 +980,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Decision { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Decision(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Decision) predicate.Decision { - return predicate.Decision(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Decision(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/decision_create.go b/pkg/database/ent/decision_create.go index 64238cb7003..f30d5452120 100644 --- a/pkg/database/ent/decision_create.go +++ b/pkg/database/ent/decision_create.go @@ -231,50 +231,8 @@ func (dc *DecisionCreate) Mutation() *DecisionMutation { // Save creates the Decision in the database. func (dc *DecisionCreate) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) dc.defaults() - if len(dc.hooks) == 0 { - if err = dc.check(); err != nil { - return nil, err - } - node, err = dc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dc.check(); err != nil { - return nil, err - } - dc.mutation = mutation - if node, err = dc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dc.hooks) - 1; i >= 0; i-- { - if dc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, dc.sqlSave, dc.mutation, dc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -317,6 +275,12 @@ func (dc *DecisionCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (dc *DecisionCreate) check() error { + if _, ok := dc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Decision.created_at"`)} + } + if _, ok := dc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Decision.updated_at"`)} + } if _, ok := dc.mutation.Scenario(); !ok { return &ValidationError{Name: "scenario", err: errors.New(`ent: missing required field "Decision.scenario"`)} } @@ -339,6 +303,9 @@ func (dc *DecisionCreate) check() error { } func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { + if err := dc.check(); err != nil { + return nil, err + } _node, _spec := dc.createSpec() if err := sqlgraph.CreateNode(ctx, dc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -348,138 +315,74 @@ func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + dc.mutation.id = &_node.ID + dc.mutation.done = true return _node, nil } func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { var ( _node = &Decision{config: dc.config} - _spec = &sqlgraph.CreateSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) ) if value, ok := dc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(decision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := dc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := dc.mutation.Until(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) _node.Until = &value } if value, ok := dc.mutation.Scenario(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) + _spec.SetField(decision.FieldScenario, field.TypeString, value) _node.Scenario = value } if value, ok := dc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) + _spec.SetField(decision.FieldType, field.TypeString, value) _node.Type = value } if value, ok := dc.mutation.StartIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.SetField(decision.FieldStartIP, field.TypeInt64, value) _node.StartIP = value } if value, ok := dc.mutation.EndIP(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.SetField(decision.FieldEndIP, field.TypeInt64, value) _node.EndIP = value } if value, ok := dc.mutation.StartSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value) _node.StartSuffix = value } if value, ok := dc.mutation.EndSuffix(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value) _node.EndSuffix = value } if value, ok := dc.mutation.IPSize(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.SetField(decision.FieldIPSize, field.TypeInt64, value) _node.IPSize = value } if value, ok := dc.mutation.Scope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) + _spec.SetField(decision.FieldScope, field.TypeString, value) _node.Scope = value } if value, ok := dc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) + _spec.SetField(decision.FieldValue, field.TypeString, value) _node.Value = value } if value, ok := dc.mutation.Origin(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) + _spec.SetField(decision.FieldOrigin, field.TypeString, value) _node.Origin = value } if value, ok := dc.mutation.Simulated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) + _spec.SetField(decision.FieldSimulated, field.TypeBool, value) _node.Simulated = value } if value, ok := dc.mutation.UUID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.SetField(decision.FieldUUID, field.TypeString, value) _node.UUID = value } if nodes := dc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -490,10 +393,7 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -508,11 +408,15 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) { // DecisionCreateBulk is the builder for creating many Decision entities in bulk. type DecisionCreateBulk struct { config + err error builders []*DecisionCreate } // Save creates the Decision entities in the database. func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { + if dcb.err != nil { + return nil, dcb.err + } specs := make([]*sqlgraph.CreateSpec, len(dcb.builders)) nodes := make([]*Decision, len(dcb.builders)) mutators := make([]Mutator, len(dcb.builders)) @@ -529,8 +433,8 @@ func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, dcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/decision_delete.go b/pkg/database/ent/decision_delete.go index 24b494b113e..35bb8767283 100644 --- a/pkg/database/ent/decision_delete.go +++ b/pkg/database/ent/decision_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (dd *DecisionDelete) Where(ps ...predicate.Decision) *DecisionDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (dd *DecisionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dd.hooks) == 0 { - affected, err = dd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dd.mutation = mutation - affected, err = dd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dd.hooks) - 1; i >= 0; i-- { - if dd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, dd.sqlExec, dd.mutation, dd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dd *DecisionDelete) ExecX(ctx context.Context) int { } func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := dd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DecisionDeleteOne struct { dd *DecisionDelete } +// Where appends a list predicates to the DecisionDelete builder. +func (ddo *DecisionDeleteOne) Where(ps ...predicate.Decision) *DecisionDeleteOne { + ddo.dd.mutation.Where(ps...) + return ddo +} + // Exec executes the deletion query. func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { n, err := ddo.dd.Exec(ctx) @@ -111,5 +82,7 @@ func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ddo *DecisionDeleteOne) ExecX(ctx context.Context) { - ddo.dd.ExecX(ctx) + if err := ddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/decision_query.go b/pkg/database/ent/decision_query.go index 91aebded968..b050a4d9649 100644 --- a/pkg/database/ent/decision_query.go +++ b/pkg/database/ent/decision_query.go @@ -18,11 +18,9 @@ import ( // DecisionQuery is the builder for querying Decision entities. type DecisionQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []decision.OrderOption + inters []Interceptor predicates []predicate.Decision withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (dq *DecisionQuery) Where(ps ...predicate.Decision) *DecisionQuery { return dq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dq *DecisionQuery) Limit(limit int) *DecisionQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } -// Offset adds an offset step to the query. +// Offset to start from. func (dq *DecisionQuery) Offset(offset int) *DecisionQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dq *DecisionQuery) Unique(unique bool) *DecisionQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } -// Order adds an order step to the query. -func (dq *DecisionQuery) Order(o ...OrderFunc) *DecisionQuery { +// Order specifies how the records should be ordered. +func (dq *DecisionQuery) Order(o ...decision.OrderOption) *DecisionQuery { dq.order = append(dq.order, o...) return dq } // QueryOwner chains the current query on the "owner" edge. func (dq *DecisionQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (dq *DecisionQuery) QueryOwner() *AlertQuery { // First returns the first Decision entity from the query. // Returns a *NotFoundError when no Decision was found. func (dq *DecisionQuery) First(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(1).All(ctx) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (dq *DecisionQuery) FirstX(ctx context.Context) *Decision { // Returns a *NotFoundError when no Decision ID was found. func (dq *DecisionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(1).IDs(ctx); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (dq *DecisionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Decision entity is found. // Returns a *NotFoundError when no Decision entities are found. func (dq *DecisionQuery) Only(ctx context.Context) (*Decision, error) { - nodes, err := dq.Limit(2).All(ctx) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (dq *DecisionQuery) OnlyX(ctx context.Context) *Decision { // Returns a *NotFoundError when no entities are found. func (dq *DecisionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(2).IDs(ctx); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (dq *DecisionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Decisions. func (dq *DecisionQuery) All(ctx context.Context) ([]*Decision, error) { + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } - return dq.sqlAll(ctx) + qr := querierAll[[]*Decision, *DecisionQuery]() + return withInterceptors[[]*Decision](ctx, dq, qr, dq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (dq *DecisionQuery) AllX(ctx context.Context) []*Decision { } // IDs executes the query and returns a list of Decision IDs. -func (dq *DecisionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { +func (dq *DecisionQuery) IDs(ctx context.Context) (ids []int, err error) { + if dq.ctx.Unique == nil && dq.path != nil { + dq.Unique(true) + } + ctx = setContextOp(ctx, dq.ctx, "IDs") + if err = dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (dq *DecisionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dq *DecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } - return dq.sqlCount(ctx) + return withInterceptors[int](ctx, dq, querierCount[*DecisionQuery](), dq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (dq *DecisionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dq *DecisionQuery) Exist(ctx context.Context) (bool, error) { - if err := dq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dq.ctx, "Exist") + switch _, err := dq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return dq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (dq *DecisionQuery) Clone() *DecisionQuery { } return &DecisionQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, - order: append([]OrderFunc{}, dq.order...), + ctx: dq.ctx.Clone(), + order: append([]decision.OrderOption{}, dq.order...), + inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Decision{}, dq.predicates...), withOwner: dq.withOwner.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { - query := &AlertQuery{config: dq.config} + query := (&AlertClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupBy { - grbuild := &DecisionGroupBy{config: dq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dq.prepareQuery(ctx); err != nil { - return nil, err - } - return dq.sqlQuery(ctx), nil - } + dq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DecisionGroupBy{build: dq} + grbuild.flds = &dq.ctx.Fields grbuild.label = decision.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupB // Select(decision.FieldCreatedAt). // Scan(ctx, &v) func (dq *DecisionQuery) Select(fields ...string) *DecisionSelect { - dq.fields = append(dq.fields, fields...) - selbuild := &DecisionSelect{DecisionQuery: dq} - selbuild.label = decision.Label - selbuild.flds, selbuild.scan = &dq.fields, selbuild.Scan - return selbuild + dq.ctx.Fields = append(dq.ctx.Fields, fields...) + sbuild := &DecisionSelect{DecisionQuery: dq} + sbuild.label = decision.Label + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DecisionSelect configured with the given aggregations. +func (dq *DecisionQuery) Aggregate(fns ...AggregateFunc) *DecisionSelect { + return dq.Select().Aggregate(fns...) } func (dq *DecisionQuery) prepareQuery(ctx context.Context) error { - for _, f := range dq.fields { + for _, inter := range dq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dq); err != nil { + return err + } + } + } + for _, f := range dq.ctx.Fields { if !decision.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) { _spec := dq.querySpec() - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } -func (dq *DecisionQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := dq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - From: dq.sql, - Unique: true, - } - if unique := dq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) + _spec.From = dq.sql + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dq.path != nil { + _spec.Unique = true } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, decision.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if dq.withOwner != nil { + _spec.Node.AddColumnOnce(decision.FieldAlertDecisions) + } } if ps := dq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(decision.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = decision.Columns } @@ -489,7 +494,7 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) } - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -498,12 +503,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { // DecisionGroupBy is the group-by builder for Decision entities. type DecisionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DecisionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (dgb *DecisionGroupBy) Aggregate(fns ...AggregateFunc) *DecisionGroupBy { return dgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (dgb *DecisionGroupBy) Scan(ctx context.Context, v any) error { - query, err := dgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") + if err := dgb.build.prepareQuery(ctx); err != nil { return err } - dgb.sql = query - return dgb.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionGroupBy](ctx, dgb.build, dgb, dgb.build.inters, v) } -func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range dgb.fields { - if !decision.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dgb.fns)) + for _, fn := range dgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dgb.flds)+len(dgb.fns)) + for _, f := range *dgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dgb.sqlQuery() + selector.GroupBy(selector.Columns(*dgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dgb *DecisionGroupBy) sqlQuery() *sql.Selector { - selector := dgb.sql.Select() - aggregation := make([]string, 0, len(dgb.fns)) - for _, fn := range dgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dgb.fields)+len(dgb.fns)) - for _, f := range dgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(dgb.fields...)...) -} - // DecisionSelect is the builder for selecting fields of Decision entities. type DecisionSelect struct { *DecisionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ds *DecisionSelect) Aggregate(fns ...AggregateFunc) *DecisionSelect { + ds.fns = append(ds.fns, fns...) + return ds } // Scan applies the selector query and scans the result into the given value. func (ds *DecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } - ds.sql = ds.DecisionQuery.sqlQuery(ctx) - return ds.sqlScan(ctx, v) + return scanWithInterceptors[*DecisionQuery, *DecisionSelect](ctx, ds.DecisionQuery, ds, ds.inters, v) } -func (ds *DecisionSelect) sqlScan(ctx context.Context, v any) error { +func (ds *DecisionSelect) sqlScan(ctx context.Context, root *DecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ds.fns)) + for _, fn := range ds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ds.sql.Query() + query, args := selector.Query() if err := ds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/decision_update.go b/pkg/database/ent/decision_update.go index 64b40871eca..68d0eb4ace7 100644 --- a/pkg/database/ent/decision_update.go +++ b/pkg/database/ent/decision_update.go @@ -29,30 +29,12 @@ func (du *DecisionUpdate) Where(ps ...predicate.Decision) *DecisionUpdate { return du } -// SetCreatedAt sets the "created_at" field. -func (du *DecisionUpdate) SetCreatedAt(t time.Time) *DecisionUpdate { - du.mutation.SetCreatedAt(t) - return du -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (du *DecisionUpdate) ClearCreatedAt() *DecisionUpdate { - du.mutation.ClearCreatedAt() - return du -} - // SetUpdatedAt sets the "updated_at" field. func (du *DecisionUpdate) SetUpdatedAt(t time.Time) *DecisionUpdate { du.mutation.SetUpdatedAt(t) return du } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (du *DecisionUpdate) ClearUpdatedAt() *DecisionUpdate { - du.mutation.ClearUpdatedAt() - return du -} - // SetUntil sets the "until" field. func (du *DecisionUpdate) SetUntil(t time.Time) *DecisionUpdate { du.mutation.SetUntil(t) @@ -73,205 +55,6 @@ func (du *DecisionUpdate) ClearUntil() *DecisionUpdate { return du } -// SetScenario sets the "scenario" field. -func (du *DecisionUpdate) SetScenario(s string) *DecisionUpdate { - du.mutation.SetScenario(s) - return du -} - -// SetType sets the "type" field. -func (du *DecisionUpdate) SetType(s string) *DecisionUpdate { - du.mutation.SetType(s) - return du -} - -// SetStartIP sets the "start_ip" field. -func (du *DecisionUpdate) SetStartIP(i int64) *DecisionUpdate { - du.mutation.ResetStartIP() - du.mutation.SetStartIP(i) - return du -} - -// SetNillableStartIP sets the "start_ip" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableStartIP(i *int64) *DecisionUpdate { - if i != nil { - du.SetStartIP(*i) - } - return du -} - -// AddStartIP adds i to the "start_ip" field. -func (du *DecisionUpdate) AddStartIP(i int64) *DecisionUpdate { - du.mutation.AddStartIP(i) - return du -} - -// ClearStartIP clears the value of the "start_ip" field. -func (du *DecisionUpdate) ClearStartIP() *DecisionUpdate { - du.mutation.ClearStartIP() - return du -} - -// SetEndIP sets the "end_ip" field. -func (du *DecisionUpdate) SetEndIP(i int64) *DecisionUpdate { - du.mutation.ResetEndIP() - du.mutation.SetEndIP(i) - return du -} - -// SetNillableEndIP sets the "end_ip" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableEndIP(i *int64) *DecisionUpdate { - if i != nil { - du.SetEndIP(*i) - } - return du -} - -// AddEndIP adds i to the "end_ip" field. -func (du *DecisionUpdate) AddEndIP(i int64) *DecisionUpdate { - du.mutation.AddEndIP(i) - return du -} - -// ClearEndIP clears the value of the "end_ip" field. -func (du *DecisionUpdate) ClearEndIP() *DecisionUpdate { - du.mutation.ClearEndIP() - return du -} - -// SetStartSuffix sets the "start_suffix" field. -func (du *DecisionUpdate) SetStartSuffix(i int64) *DecisionUpdate { - du.mutation.ResetStartSuffix() - du.mutation.SetStartSuffix(i) - return du -} - -// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableStartSuffix(i *int64) *DecisionUpdate { - if i != nil { - du.SetStartSuffix(*i) - } - return du -} - -// AddStartSuffix adds i to the "start_suffix" field. -func (du *DecisionUpdate) AddStartSuffix(i int64) *DecisionUpdate { - du.mutation.AddStartSuffix(i) - return du -} - -// ClearStartSuffix clears the value of the "start_suffix" field. -func (du *DecisionUpdate) ClearStartSuffix() *DecisionUpdate { - du.mutation.ClearStartSuffix() - return du -} - -// SetEndSuffix sets the "end_suffix" field. -func (du *DecisionUpdate) SetEndSuffix(i int64) *DecisionUpdate { - du.mutation.ResetEndSuffix() - du.mutation.SetEndSuffix(i) - return du -} - -// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableEndSuffix(i *int64) *DecisionUpdate { - if i != nil { - du.SetEndSuffix(*i) - } - return du -} - -// AddEndSuffix adds i to the "end_suffix" field. -func (du *DecisionUpdate) AddEndSuffix(i int64) *DecisionUpdate { - du.mutation.AddEndSuffix(i) - return du -} - -// ClearEndSuffix clears the value of the "end_suffix" field. -func (du *DecisionUpdate) ClearEndSuffix() *DecisionUpdate { - du.mutation.ClearEndSuffix() - return du -} - -// SetIPSize sets the "ip_size" field. -func (du *DecisionUpdate) SetIPSize(i int64) *DecisionUpdate { - du.mutation.ResetIPSize() - du.mutation.SetIPSize(i) - return du -} - -// SetNillableIPSize sets the "ip_size" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableIPSize(i *int64) *DecisionUpdate { - if i != nil { - du.SetIPSize(*i) - } - return du -} - -// AddIPSize adds i to the "ip_size" field. -func (du *DecisionUpdate) AddIPSize(i int64) *DecisionUpdate { - du.mutation.AddIPSize(i) - return du -} - -// ClearIPSize clears the value of the "ip_size" field. -func (du *DecisionUpdate) ClearIPSize() *DecisionUpdate { - du.mutation.ClearIPSize() - return du -} - -// SetScope sets the "scope" field. -func (du *DecisionUpdate) SetScope(s string) *DecisionUpdate { - du.mutation.SetScope(s) - return du -} - -// SetValue sets the "value" field. -func (du *DecisionUpdate) SetValue(s string) *DecisionUpdate { - du.mutation.SetValue(s) - return du -} - -// SetOrigin sets the "origin" field. -func (du *DecisionUpdate) SetOrigin(s string) *DecisionUpdate { - du.mutation.SetOrigin(s) - return du -} - -// SetSimulated sets the "simulated" field. -func (du *DecisionUpdate) SetSimulated(b bool) *DecisionUpdate { - du.mutation.SetSimulated(b) - return du -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableSimulated(b *bool) *DecisionUpdate { - if b != nil { - du.SetSimulated(*b) - } - return du -} - -// SetUUID sets the "uuid" field. -func (du *DecisionUpdate) SetUUID(s string) *DecisionUpdate { - du.mutation.SetUUID(s) - return du -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (du *DecisionUpdate) SetNillableUUID(s *string) *DecisionUpdate { - if s != nil { - du.SetUUID(*s) - } - return du -} - -// ClearUUID clears the value of the "uuid" field. -func (du *DecisionUpdate) ClearUUID() *DecisionUpdate { - du.mutation.ClearUUID() - return du -} - // SetAlertDecisions sets the "alert_decisions" field. func (du *DecisionUpdate) SetAlertDecisions(i int) *DecisionUpdate { du.mutation.SetAlertDecisions(i) @@ -324,35 +107,8 @@ func (du *DecisionUpdate) ClearOwner() *DecisionUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (du *DecisionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) du.defaults() - if len(du.hooks) == 0 { - affected, err = du.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - du.mutation = mutation - affected, err = du.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(du.hooks) - 1; i >= 0; i-- { - if du.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = du.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, du.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, du.sqlSave, du.mutation, du.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -379,27 +135,14 @@ func (du *DecisionUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (du *DecisionUpdate) defaults() { - if _, ok := du.mutation.CreatedAt(); !ok && !du.mutation.CreatedAtCleared() { - v := decision.UpdateDefaultCreatedAt() - du.mutation.SetCreatedAt(v) - } - if _, ok := du.mutation.UpdatedAt(); !ok && !du.mutation.UpdatedAtCleared() { + if _, ok := du.mutation.UpdatedAt(); !ok { v := decision.UpdateDefaultUpdatedAt() du.mutation.SetUpdatedAt(v) } } func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) if ps := du.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -407,199 +150,32 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := du.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - } - if du.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) - } if value, ok := du.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - } - if du.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if value, ok := du.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if du.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) - } - if value, ok := du.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) - } - if value, ok := du.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) - } - if value, ok := du.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) - } - if value, ok := du.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if du.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) - } - if value, ok := du.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) - } - if value, ok := du.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if du.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) - } - if value, ok := du.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := du.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if du.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := du.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := du.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if du.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := du.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) - } - if value, ok := du.mutation.AddedIPSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if du.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) - } - if value, ok := du.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) - } - if value, ok := du.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) - } - if value, ok := du.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) - } - if value, ok := du.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) - } - if value, ok := du.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if du.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldUUID, field.TypeString) } if du.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -609,10 +185,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -625,10 +198,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -644,6 +214,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + du.mutation.done = true return n, nil } @@ -655,30 +226,12 @@ type DecisionUpdateOne struct { mutation *DecisionMutation } -// SetCreatedAt sets the "created_at" field. -func (duo *DecisionUpdateOne) SetCreatedAt(t time.Time) *DecisionUpdateOne { - duo.mutation.SetCreatedAt(t) - return duo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (duo *DecisionUpdateOne) ClearCreatedAt() *DecisionUpdateOne { - duo.mutation.ClearCreatedAt() - return duo -} - // SetUpdatedAt sets the "updated_at" field. func (duo *DecisionUpdateOne) SetUpdatedAt(t time.Time) *DecisionUpdateOne { duo.mutation.SetUpdatedAt(t) return duo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (duo *DecisionUpdateOne) ClearUpdatedAt() *DecisionUpdateOne { - duo.mutation.ClearUpdatedAt() - return duo -} - // SetUntil sets the "until" field. func (duo *DecisionUpdateOne) SetUntil(t time.Time) *DecisionUpdateOne { duo.mutation.SetUntil(t) @@ -699,205 +252,6 @@ func (duo *DecisionUpdateOne) ClearUntil() *DecisionUpdateOne { return duo } -// SetScenario sets the "scenario" field. -func (duo *DecisionUpdateOne) SetScenario(s string) *DecisionUpdateOne { - duo.mutation.SetScenario(s) - return duo -} - -// SetType sets the "type" field. -func (duo *DecisionUpdateOne) SetType(s string) *DecisionUpdateOne { - duo.mutation.SetType(s) - return duo -} - -// SetStartIP sets the "start_ip" field. -func (duo *DecisionUpdateOne) SetStartIP(i int64) *DecisionUpdateOne { - duo.mutation.ResetStartIP() - duo.mutation.SetStartIP(i) - return duo -} - -// SetNillableStartIP sets the "start_ip" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableStartIP(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetStartIP(*i) - } - return duo -} - -// AddStartIP adds i to the "start_ip" field. -func (duo *DecisionUpdateOne) AddStartIP(i int64) *DecisionUpdateOne { - duo.mutation.AddStartIP(i) - return duo -} - -// ClearStartIP clears the value of the "start_ip" field. -func (duo *DecisionUpdateOne) ClearStartIP() *DecisionUpdateOne { - duo.mutation.ClearStartIP() - return duo -} - -// SetEndIP sets the "end_ip" field. -func (duo *DecisionUpdateOne) SetEndIP(i int64) *DecisionUpdateOne { - duo.mutation.ResetEndIP() - duo.mutation.SetEndIP(i) - return duo -} - -// SetNillableEndIP sets the "end_ip" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableEndIP(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetEndIP(*i) - } - return duo -} - -// AddEndIP adds i to the "end_ip" field. -func (duo *DecisionUpdateOne) AddEndIP(i int64) *DecisionUpdateOne { - duo.mutation.AddEndIP(i) - return duo -} - -// ClearEndIP clears the value of the "end_ip" field. -func (duo *DecisionUpdateOne) ClearEndIP() *DecisionUpdateOne { - duo.mutation.ClearEndIP() - return duo -} - -// SetStartSuffix sets the "start_suffix" field. -func (duo *DecisionUpdateOne) SetStartSuffix(i int64) *DecisionUpdateOne { - duo.mutation.ResetStartSuffix() - duo.mutation.SetStartSuffix(i) - return duo -} - -// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableStartSuffix(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetStartSuffix(*i) - } - return duo -} - -// AddStartSuffix adds i to the "start_suffix" field. -func (duo *DecisionUpdateOne) AddStartSuffix(i int64) *DecisionUpdateOne { - duo.mutation.AddStartSuffix(i) - return duo -} - -// ClearStartSuffix clears the value of the "start_suffix" field. -func (duo *DecisionUpdateOne) ClearStartSuffix() *DecisionUpdateOne { - duo.mutation.ClearStartSuffix() - return duo -} - -// SetEndSuffix sets the "end_suffix" field. -func (duo *DecisionUpdateOne) SetEndSuffix(i int64) *DecisionUpdateOne { - duo.mutation.ResetEndSuffix() - duo.mutation.SetEndSuffix(i) - return duo -} - -// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableEndSuffix(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetEndSuffix(*i) - } - return duo -} - -// AddEndSuffix adds i to the "end_suffix" field. -func (duo *DecisionUpdateOne) AddEndSuffix(i int64) *DecisionUpdateOne { - duo.mutation.AddEndSuffix(i) - return duo -} - -// ClearEndSuffix clears the value of the "end_suffix" field. -func (duo *DecisionUpdateOne) ClearEndSuffix() *DecisionUpdateOne { - duo.mutation.ClearEndSuffix() - return duo -} - -// SetIPSize sets the "ip_size" field. -func (duo *DecisionUpdateOne) SetIPSize(i int64) *DecisionUpdateOne { - duo.mutation.ResetIPSize() - duo.mutation.SetIPSize(i) - return duo -} - -// SetNillableIPSize sets the "ip_size" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableIPSize(i *int64) *DecisionUpdateOne { - if i != nil { - duo.SetIPSize(*i) - } - return duo -} - -// AddIPSize adds i to the "ip_size" field. -func (duo *DecisionUpdateOne) AddIPSize(i int64) *DecisionUpdateOne { - duo.mutation.AddIPSize(i) - return duo -} - -// ClearIPSize clears the value of the "ip_size" field. -func (duo *DecisionUpdateOne) ClearIPSize() *DecisionUpdateOne { - duo.mutation.ClearIPSize() - return duo -} - -// SetScope sets the "scope" field. -func (duo *DecisionUpdateOne) SetScope(s string) *DecisionUpdateOne { - duo.mutation.SetScope(s) - return duo -} - -// SetValue sets the "value" field. -func (duo *DecisionUpdateOne) SetValue(s string) *DecisionUpdateOne { - duo.mutation.SetValue(s) - return duo -} - -// SetOrigin sets the "origin" field. -func (duo *DecisionUpdateOne) SetOrigin(s string) *DecisionUpdateOne { - duo.mutation.SetOrigin(s) - return duo -} - -// SetSimulated sets the "simulated" field. -func (duo *DecisionUpdateOne) SetSimulated(b bool) *DecisionUpdateOne { - duo.mutation.SetSimulated(b) - return duo -} - -// SetNillableSimulated sets the "simulated" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableSimulated(b *bool) *DecisionUpdateOne { - if b != nil { - duo.SetSimulated(*b) - } - return duo -} - -// SetUUID sets the "uuid" field. -func (duo *DecisionUpdateOne) SetUUID(s string) *DecisionUpdateOne { - duo.mutation.SetUUID(s) - return duo -} - -// SetNillableUUID sets the "uuid" field if the given value is not nil. -func (duo *DecisionUpdateOne) SetNillableUUID(s *string) *DecisionUpdateOne { - if s != nil { - duo.SetUUID(*s) - } - return duo -} - -// ClearUUID clears the value of the "uuid" field. -func (duo *DecisionUpdateOne) ClearUUID() *DecisionUpdateOne { - duo.mutation.ClearUUID() - return duo -} - // SetAlertDecisions sets the "alert_decisions" field. func (duo *DecisionUpdateOne) SetAlertDecisions(i int) *DecisionUpdateOne { duo.mutation.SetAlertDecisions(i) @@ -948,6 +302,12 @@ func (duo *DecisionUpdateOne) ClearOwner() *DecisionUpdateOne { return duo } +// Where appends a list predicates to the DecisionUpdate builder. +func (duo *DecisionUpdateOne) Where(ps ...predicate.Decision) *DecisionUpdateOne { + duo.mutation.Where(ps...) + return duo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUpdateOne { @@ -957,41 +317,8 @@ func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUp // Save executes the query and returns the updated Decision entity. func (duo *DecisionUpdateOne) Save(ctx context.Context) (*Decision, error) { - var ( - err error - node *Decision - ) duo.defaults() - if len(duo.hooks) == 0 { - node, err = duo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - duo.mutation = mutation - node, err = duo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(duo.hooks) - 1; i >= 0; i-- { - if duo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = duo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, duo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Decision) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, duo.sqlSave, duo.mutation, duo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -1018,27 +345,14 @@ func (duo *DecisionUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (duo *DecisionUpdateOne) defaults() { - if _, ok := duo.mutation.CreatedAt(); !ok && !duo.mutation.CreatedAtCleared() { - v := decision.UpdateDefaultCreatedAt() - duo.mutation.SetCreatedAt(v) - } - if _, ok := duo.mutation.UpdatedAt(); !ok && !duo.mutation.UpdatedAtCleared() { + if _, ok := duo.mutation.UpdatedAt(); !ok { v := decision.UpdateDefaultUpdatedAt() duo.mutation.SetUpdatedAt(v) } } func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: decision.Table, - Columns: decision.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: decision.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt)) id, ok := duo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Decision.id" for update`)} @@ -1063,199 +377,32 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } } } - if value, ok := duo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldCreatedAt, - }) - } - if duo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldCreatedAt, - }) - } if value, ok := duo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUpdatedAt, - }) - } - if duo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUpdatedAt, - }) + _spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value) } if value, ok := duo.mutation.Until(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: decision.FieldUntil, - }) + _spec.SetField(decision.FieldUntil, field.TypeTime, value) } if duo.mutation.UntilCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: decision.FieldUntil, - }) - } - if value, ok := duo.mutation.Scenario(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScenario, - }) - } - if value, ok := duo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldType, - }) - } - if value, ok := duo.mutation.StartIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) - } - if value, ok := duo.mutation.AddedStartIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartIP, - }) + _spec.ClearField(decision.FieldUntil, field.TypeTime) } if duo.mutation.StartIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartIP, - }) - } - if value, ok := duo.mutation.EndIP(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) - } - if value, ok := duo.mutation.AddedEndIP(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndIP, - }) + _spec.ClearField(decision.FieldStartIP, field.TypeInt64) } if duo.mutation.EndIPCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndIP, - }) - } - if value, ok := duo.mutation.StartSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := duo.mutation.AddedStartSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldStartSuffix, - }) + _spec.ClearField(decision.FieldEndIP, field.TypeInt64) } if duo.mutation.StartSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldStartSuffix, - }) - } - if value, ok := duo.mutation.EndSuffix(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := duo.mutation.AddedEndSuffix(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldEndSuffix, - }) + _spec.ClearField(decision.FieldStartSuffix, field.TypeInt64) } if duo.mutation.EndSuffixCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldEndSuffix, - }) - } - if value, ok := duo.mutation.IPSize(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) - } - if value, ok := duo.mutation.AddedIPSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Value: value, - Column: decision.FieldIPSize, - }) + _spec.ClearField(decision.FieldEndSuffix, field.TypeInt64) } if duo.mutation.IPSizeCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeInt64, - Column: decision.FieldIPSize, - }) - } - if value, ok := duo.mutation.Scope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldScope, - }) - } - if value, ok := duo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldValue, - }) - } - if value, ok := duo.mutation.Origin(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldOrigin, - }) - } - if value, ok := duo.mutation.Simulated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: decision.FieldSimulated, - }) - } - if value, ok := duo.mutation.UUID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldIPSize, field.TypeInt64) } if duo.mutation.UUIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: decision.FieldUUID, - }) + _spec.ClearField(decision.FieldUUID, field.TypeString) } if duo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1265,10 +412,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1281,10 +425,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err Columns: []string{decision.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1303,5 +444,6 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err } return nil, err } + duo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go index 0455af444d2..2a5ad188197 100644 --- a/pkg/database/ent/ent.go +++ b/pkg/database/ent/ent.go @@ -6,6 +6,8 @@ import ( "context" "errors" "fmt" + "reflect" + "sync" "entgo.io/ent" "entgo.io/ent/dialect/sql" @@ -15,56 +17,89 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) // ent aliases to avoid import conflicts in user's code. type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + // OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. type OrderFunc func(*sql.Selector) -// columnChecker returns a function indicates if the column exists in the given column. -func columnChecker(table string) func(string) error { - checks := map[string]func(string) bool{ - alert.Table: alert.ValidColumn, - bouncer.Table: bouncer.ValidColumn, - configitem.Table: configitem.ValidColumn, - decision.Table: decision.ValidColumn, - event.Table: event.ValidColumn, - machine.Table: machine.ValidColumn, - meta.Table: meta.ValidColumn, - } - check, ok := checks[table] - if !ok { - return func(string) error { - return fmt.Errorf("unknown table %q", table) - } - } - return func(column string) error { - if !check(column) { - return fmt.Errorf("unknown column %q for table %q", column, table) - } - return nil - } +var ( + initCheck sync.Once + columnCheck sql.ColumnCheck +) + +// columnChecker checks if the column exists in the given table. +func checkColumn(table, column string) error { + initCheck.Do(func() { + columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + alert.Table: alert.ValidColumn, + bouncer.Table: bouncer.ValidColumn, + configitem.Table: configitem.ValidColumn, + decision.Table: decision.ValidColumn, + event.Table: event.ValidColumn, + lock.Table: lock.ValidColumn, + machine.Table: machine.ValidColumn, + meta.Table: meta.ValidColumn, + metric.Table: metric.ValidColumn, + }) + }) + return columnCheck(table, column) } // Asc applies the given fields in ASC order. -func Asc(fields ...string) OrderFunc { +func Asc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Asc(s.C(f))) @@ -73,11 +108,10 @@ func Asc(fields ...string) OrderFunc { } // Desc applies the given fields in DESC order. -func Desc(fields ...string) OrderFunc { +func Desc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Desc(s.C(f))) @@ -109,8 +143,7 @@ func Count() AggregateFunc { // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -121,8 +154,7 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -133,8 +165,7 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -145,8 +176,7 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -275,6 +305,7 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string + fns []AggregateFunc scan func(context.Context, any) error } @@ -473,5 +504,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/pkg/database/ent/event.go b/pkg/database/ent/event.go index 4754107fddc..b57f1f34ac9 100644 --- a/pkg/database/ent/event.go +++ b/pkg/database/ent/event.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" @@ -18,9 +19,9 @@ type Event struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Time holds the value of the "time" field. Time time.Time `json:"time,omitempty"` // Serialized holds the value of the "serialized" field. @@ -29,7 +30,8 @@ type Event struct { AlertEvents int `json:"alert_events,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the EventQuery when eager-loading is set. - Edges EventEdges `json:"edges"` + Edges EventEdges `json:"edges"` + selectValues sql.SelectValues } // EventEdges holds the relations/edges for other nodes in the graph. @@ -44,12 +46,10 @@ type EventEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e EventEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. - return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -66,7 +66,7 @@ func (*Event) scanValues(columns []string) ([]any, error) { case event.FieldCreatedAt, event.FieldUpdatedAt, event.FieldTime: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Event", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -90,15 +90,13 @@ func (e *Event) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - e.CreatedAt = new(time.Time) - *e.CreatedAt = value.Time + e.CreatedAt = value.Time } case event.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - e.UpdatedAt = new(time.Time) - *e.UpdatedAt = value.Time + e.UpdatedAt = value.Time } case event.FieldTime: if value, ok := values[i].(*sql.NullTime); !ok { @@ -118,21 +116,29 @@ func (e *Event) assignValues(columns []string, values []any) error { } else if value.Valid { e.AlertEvents = int(value.Int64) } + default: + e.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Event. +// This includes values selected through modifiers, order, etc. +func (e *Event) Value(name string) (ent.Value, error) { + return e.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Event entity. func (e *Event) QueryOwner() *AlertQuery { - return (&EventClient{config: e.config}).QueryOwner(e) + return NewEventClient(e.config).QueryOwner(e) } // Update returns a builder for updating this Event. // Note that you need to call Event.Unwrap() before calling this method if this Event // was returned from a transaction, and the transaction was committed or rolled back. func (e *Event) Update() *EventUpdateOne { - return (&EventClient{config: e.config}).UpdateOne(e) + return NewEventClient(e.config).UpdateOne(e) } // Unwrap unwraps the Event entity that was returned from a transaction after it was closed, @@ -151,15 +157,11 @@ func (e *Event) String() string { var builder strings.Builder builder.WriteString("Event(") builder.WriteString(fmt.Sprintf("id=%v, ", e.ID)) - if v := e.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(e.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := e.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(e.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("time=") builder.WriteString(e.Time.Format(time.ANSIC)) @@ -175,9 +177,3 @@ func (e *Event) String() string { // Events is a parsable slice of Event. type Events []*Event - -func (e Events) config(cfg config) { - for _i := range e { - e[_i].config = cfg - } -} diff --git a/pkg/database/ent/event/event.go b/pkg/database/ent/event/event.go index 33b9b67f8b9..c975a612669 100644 --- a/pkg/database/ent/event/event.go +++ b/pkg/database/ent/event/event.go @@ -4,6 +4,9 @@ package event import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,8 +60,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -66,3 +67,50 @@ var ( // SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. SerializedValidator func(string) error ) + +// OrderOption defines the ordering options for the Event queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByTime orders the results by the time field. +func ByTime(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTime, opts...).ToFunc() +} + +// BySerialized orders the results by the serialized field. +func BySerialized(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSerialized, opts...).ToFunc() +} + +// ByAlertEvents orders the results by the alert_events field. +func ByAlertEvents(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertEvents, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/event/where.go b/pkg/database/ent/event/where.go index 7554e59e678..d420b125026 100644 --- a/pkg/database/ent/event/where.go +++ b/pkg/database/ent/event/where.go @@ -12,477 +12,287 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Event(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // Time applies equality check predicate on the "time" field. It's identical to TimeEQ. func Time(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // Serialized applies equality check predicate on the "serialized" field. It's identical to SerializedEQ. func Serialized(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // AlertEvents applies equality check predicate on the "alert_events" field. It's identical to AlertEventsEQ. func AlertEvents(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Event(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Event(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Event(sql.FieldLTE(FieldUpdatedAt, v)) } // TimeEQ applies the EQ predicate on the "time" field. func TimeEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldEQ(FieldTime, v)) } // TimeNEQ applies the NEQ predicate on the "time" field. func TimeNEQ(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldTime, v)) } // TimeIn applies the In predicate on the "time" field. func TimeIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldIn(FieldTime, vs...)) } // TimeNotIn applies the NotIn predicate on the "time" field. func TimeNotIn(vs ...time.Time) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldTime), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldTime, vs...)) } // TimeGT applies the GT predicate on the "time" field. func TimeGT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGT(FieldTime, v)) } // TimeGTE applies the GTE predicate on the "time" field. func TimeGTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldGTE(FieldTime, v)) } // TimeLT applies the LT predicate on the "time" field. func TimeLT(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLT(FieldTime, v)) } // TimeLTE applies the LTE predicate on the "time" field. func TimeLTE(v time.Time) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldTime), v)) - }) + return predicate.Event(sql.FieldLTE(FieldTime, v)) } // SerializedEQ applies the EQ predicate on the "serialized" field. func SerializedEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEQ(FieldSerialized, v)) } // SerializedNEQ applies the NEQ predicate on the "serialized" field. func SerializedNEQ(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldSerialized, v)) } // SerializedIn applies the In predicate on the "serialized" field. func SerializedIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldIn(FieldSerialized, vs...)) } // SerializedNotIn applies the NotIn predicate on the "serialized" field. func SerializedNotIn(vs ...string) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSerialized), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldSerialized, vs...)) } // SerializedGT applies the GT predicate on the "serialized" field. func SerializedGT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGT(FieldSerialized, v)) } // SerializedGTE applies the GTE predicate on the "serialized" field. func SerializedGTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldGTE(FieldSerialized, v)) } // SerializedLT applies the LT predicate on the "serialized" field. func SerializedLT(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLT(FieldSerialized, v)) } // SerializedLTE applies the LTE predicate on the "serialized" field. func SerializedLTE(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldLTE(FieldSerialized, v)) } // SerializedContains applies the Contains predicate on the "serialized" field. func SerializedContains(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContains(FieldSerialized, v)) } // SerializedHasPrefix applies the HasPrefix predicate on the "serialized" field. func SerializedHasPrefix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasPrefix(FieldSerialized, v)) } // SerializedHasSuffix applies the HasSuffix predicate on the "serialized" field. func SerializedHasSuffix(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldHasSuffix(FieldSerialized, v)) } // SerializedEqualFold applies the EqualFold predicate on the "serialized" field. func SerializedEqualFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldEqualFold(FieldSerialized, v)) } // SerializedContainsFold applies the ContainsFold predicate on the "serialized" field. func SerializedContainsFold(v string) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSerialized), v)) - }) + return predicate.Event(sql.FieldContainsFold(FieldSerialized, v)) } // AlertEventsEQ applies the EQ predicate on the "alert_events" field. func AlertEventsEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldEQ(FieldAlertEvents, v)) } // AlertEventsNEQ applies the NEQ predicate on the "alert_events" field. func AlertEventsNEQ(v int) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertEvents), v)) - }) + return predicate.Event(sql.FieldNEQ(FieldAlertEvents, v)) } // AlertEventsIn applies the In predicate on the "alert_events" field. func AlertEventsIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldIn(FieldAlertEvents, vs...)) } // AlertEventsNotIn applies the NotIn predicate on the "alert_events" field. func AlertEventsNotIn(vs ...int) predicate.Event { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertEvents), v...)) - }) + return predicate.Event(sql.FieldNotIn(FieldAlertEvents, vs...)) } // AlertEventsIsNil applies the IsNil predicate on the "alert_events" field. func AlertEventsIsNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldIsNull(FieldAlertEvents)) } // AlertEventsNotNil applies the NotNil predicate on the "alert_events" field. func AlertEventsNotNil() predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertEvents))) - }) + return predicate.Event(sql.FieldNotNull(FieldAlertEvents)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -490,7 +300,6 @@ func HasOwner() predicate.Event { return predicate.Event(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -500,11 +309,7 @@ func HasOwner() predicate.Event { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Event { return predicate.Event(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -515,32 +320,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Event { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Event(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Event) predicate.Event { - return predicate.Event(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Event(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/event_create.go b/pkg/database/ent/event_create.go index c5861305130..36747babe47 100644 --- a/pkg/database/ent/event_create.go +++ b/pkg/database/ent/event_create.go @@ -101,50 +101,8 @@ func (ec *EventCreate) Mutation() *EventMutation { // Save creates the Event in the database. func (ec *EventCreate) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) ec.defaults() - if len(ec.hooks) == 0 { - if err = ec.check(); err != nil { - return nil, err - } - node, err = ec.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ec.check(); err != nil { - return nil, err - } - ec.mutation = mutation - if node, err = ec.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ec.hooks) - 1; i >= 0; i-- { - if ec.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ec.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ec.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ec.sqlSave, ec.mutation, ec.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -183,6 +141,12 @@ func (ec *EventCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (ec *EventCreate) check() error { + if _, ok := ec.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Event.created_at"`)} + } + if _, ok := ec.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Event.updated_at"`)} + } if _, ok := ec.mutation.Time(); !ok { return &ValidationError{Name: "time", err: errors.New(`ent: missing required field "Event.time"`)} } @@ -198,6 +162,9 @@ func (ec *EventCreate) check() error { } func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { + if err := ec.check(); err != nil { + return nil, err + } _node, _spec := ec.createSpec() if err := sqlgraph.CreateNode(ctx, ec.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +174,30 @@ func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + ec.mutation.id = &_node.ID + ec.mutation.done = true return _node, nil } func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { var ( _node = &Event{config: ec.config} - _spec = &sqlgraph.CreateSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) ) if value, ok := ec.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(event.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := ec.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := ec.mutation.Time(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) + _spec.SetField(event.FieldTime, field.TypeTime, value) _node.Time = value } if value, ok := ec.mutation.Serialized(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldSerialized, field.TypeString, value) _node.Serialized = value } if nodes := ec.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +208,7 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -279,11 +223,15 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) { // EventCreateBulk is the builder for creating many Event entities in bulk. type EventCreateBulk struct { config + err error builders []*EventCreate } // Save creates the Event entities in the database. func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { + if ecb.err != nil { + return nil, ecb.err + } specs := make([]*sqlgraph.CreateSpec, len(ecb.builders)) nodes := make([]*Event, len(ecb.builders)) mutators := make([]Mutator, len(ecb.builders)) @@ -300,8 +248,8 @@ func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, ecb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/event_delete.go b/pkg/database/ent/event_delete.go index 0220dc71d31..93dd1246b7e 100644 --- a/pkg/database/ent/event_delete.go +++ b/pkg/database/ent/event_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ed *EventDelete) Where(ps ...predicate.Event) *EventDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ed *EventDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ed.hooks) == 0 { - affected, err = ed.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ed.mutation = mutation - affected, err = ed.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ed.hooks) - 1; i >= 0; i-- { - if ed.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ed.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ed.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ed.sqlExec, ed.mutation, ed.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ed *EventDelete) ExecX(ctx context.Context) int { } func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := ed.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ed.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type EventDeleteOne struct { ed *EventDelete } +// Where appends a list predicates to the EventDelete builder. +func (edo *EventDeleteOne) Where(ps ...predicate.Event) *EventDeleteOne { + edo.ed.mutation.Where(ps...) + return edo +} + // Exec executes the deletion query. func (edo *EventDeleteOne) Exec(ctx context.Context) error { n, err := edo.ed.Exec(ctx) @@ -111,5 +82,7 @@ func (edo *EventDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (edo *EventDeleteOne) ExecX(ctx context.Context) { - edo.ed.ExecX(ctx) + if err := edo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/event_query.go b/pkg/database/ent/event_query.go index 045d750f818..1493d7bd32c 100644 --- a/pkg/database/ent/event_query.go +++ b/pkg/database/ent/event_query.go @@ -18,11 +18,9 @@ import ( // EventQuery is the builder for querying Event entities. type EventQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []event.OrderOption + inters []Interceptor predicates []predicate.Event withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (eq *EventQuery) Where(ps ...predicate.Event) *EventQuery { return eq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (eq *EventQuery) Limit(limit int) *EventQuery { - eq.limit = &limit + eq.ctx.Limit = &limit return eq } -// Offset adds an offset step to the query. +// Offset to start from. func (eq *EventQuery) Offset(offset int) *EventQuery { - eq.offset = &offset + eq.ctx.Offset = &offset return eq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (eq *EventQuery) Unique(unique bool) *EventQuery { - eq.unique = &unique + eq.ctx.Unique = &unique return eq } -// Order adds an order step to the query. -func (eq *EventQuery) Order(o ...OrderFunc) *EventQuery { +// Order specifies how the records should be ordered. +func (eq *EventQuery) Order(o ...event.OrderOption) *EventQuery { eq.order = append(eq.order, o...) return eq } // QueryOwner chains the current query on the "owner" edge. func (eq *EventQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := eq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (eq *EventQuery) QueryOwner() *AlertQuery { // First returns the first Event entity from the query. // Returns a *NotFoundError when no Event was found. func (eq *EventQuery) First(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(1).All(ctx) + nodes, err := eq.Limit(1).All(setContextOp(ctx, eq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (eq *EventQuery) FirstX(ctx context.Context) *Event { // Returns a *NotFoundError when no Event ID was found. func (eq *EventQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(1).IDs(ctx); err != nil { + if ids, err = eq.Limit(1).IDs(setContextOp(ctx, eq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (eq *EventQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Event entity is found. // Returns a *NotFoundError when no Event entities are found. func (eq *EventQuery) Only(ctx context.Context) (*Event, error) { - nodes, err := eq.Limit(2).All(ctx) + nodes, err := eq.Limit(2).All(setContextOp(ctx, eq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (eq *EventQuery) OnlyX(ctx context.Context) *Event { // Returns a *NotFoundError when no entities are found. func (eq *EventQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = eq.Limit(2).IDs(ctx); err != nil { + if ids, err = eq.Limit(2).IDs(setContextOp(ctx, eq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (eq *EventQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Events. func (eq *EventQuery) All(ctx context.Context) ([]*Event, error) { + ctx = setContextOp(ctx, eq.ctx, "All") if err := eq.prepareQuery(ctx); err != nil { return nil, err } - return eq.sqlAll(ctx) + qr := querierAll[[]*Event, *EventQuery]() + return withInterceptors[[]*Event](ctx, eq, qr, eq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (eq *EventQuery) AllX(ctx context.Context) []*Event { } // IDs executes the query and returns a list of Event IDs. -func (eq *EventQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { +func (eq *EventQuery) IDs(ctx context.Context) (ids []int, err error) { + if eq.ctx.Unique == nil && eq.path != nil { + eq.Unique(true) + } + ctx = setContextOp(ctx, eq.ctx, "IDs") + if err = eq.Select(event.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (eq *EventQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (eq *EventQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, eq.ctx, "Count") if err := eq.prepareQuery(ctx); err != nil { return 0, err } - return eq.sqlCount(ctx) + return withInterceptors[int](ctx, eq, querierCount[*EventQuery](), eq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (eq *EventQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (eq *EventQuery) Exist(ctx context.Context) (bool, error) { - if err := eq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, eq.ctx, "Exist") + switch _, err := eq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return eq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (eq *EventQuery) Clone() *EventQuery { } return &EventQuery{ config: eq.config, - limit: eq.limit, - offset: eq.offset, - order: append([]OrderFunc{}, eq.order...), + ctx: eq.ctx.Clone(), + order: append([]event.OrderOption{}, eq.order...), + inters: append([]Interceptor{}, eq.inters...), predicates: append([]predicate.Event{}, eq.predicates...), withOwner: eq.withOwner.Clone(), // clone intermediate query. - sql: eq.sql.Clone(), - path: eq.path, - unique: eq.unique, + sql: eq.sql.Clone(), + path: eq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { - query := &AlertQuery{config: eq.config} + query := (&AlertClient{config: eq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { - grbuild := &EventGroupBy{config: eq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := eq.prepareQuery(ctx); err != nil { - return nil, err - } - return eq.sqlQuery(ctx), nil - } + eq.ctx.Fields = append([]string{field}, fields...) + grbuild := &EventGroupBy{build: eq} + grbuild.flds = &eq.ctx.Fields grbuild.label = event.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy { // Select(event.FieldCreatedAt). // Scan(ctx, &v) func (eq *EventQuery) Select(fields ...string) *EventSelect { - eq.fields = append(eq.fields, fields...) - selbuild := &EventSelect{EventQuery: eq} - selbuild.label = event.Label - selbuild.flds, selbuild.scan = &eq.fields, selbuild.Scan - return selbuild + eq.ctx.Fields = append(eq.ctx.Fields, fields...) + sbuild := &EventSelect{EventQuery: eq} + sbuild.label = event.Label + sbuild.flds, sbuild.scan = &eq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a EventSelect configured with the given aggregations. +func (eq *EventQuery) Aggregate(fns ...AggregateFunc) *EventSelect { + return eq.Select().Aggregate(fns...) } func (eq *EventQuery) prepareQuery(ctx context.Context) error { - for _, f := range eq.fields { + for _, inter := range eq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, eq); err != nil { + return err + } + } + } + for _, f := range eq.ctx.Fields { if !event.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes [] func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) { _spec := eq.querySpec() - _spec.Node.Columns = eq.fields - if len(eq.fields) > 0 { - _spec.Unique = eq.unique != nil && *eq.unique + _spec.Node.Columns = eq.ctx.Fields + if len(eq.ctx.Fields) > 0 { + _spec.Unique = eq.ctx.Unique != nil && *eq.ctx.Unique } return sqlgraph.CountNodes(ctx, eq.driver, _spec) } -func (eq *EventQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := eq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - From: eq.sql, - Unique: true, - } - if unique := eq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) + _spec.From = eq.sql + if unique := eq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if eq.path != nil { + _spec.Unique = true } - if fields := eq.fields; len(fields) > 0 { + if fields := eq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, event.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if eq.withOwner != nil { + _spec.Node.AddColumnOnce(event.FieldAlertEvents) + } } if ps := eq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := eq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec { func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(eq.driver.Dialect()) t1 := builder.Table(event.Table) - columns := eq.fields + columns := eq.ctx.Fields if len(columns) == 0 { columns = event.Columns } @@ -489,7 +494,7 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = eq.sql selector.Select(selector.Columns(columns...)...) } - if eq.unique != nil && *eq.unique { + if eq.ctx.Unique != nil && *eq.ctx.Unique { selector.Distinct() } for _, p := range eq.predicates { @@ -498,12 +503,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range eq.order { p(selector) } - if offset := eq.offset; offset != nil { + if offset := eq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := eq.limit; limit != nil { + if limit := eq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector { // EventGroupBy is the group-by builder for Event entities. type EventGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *EventQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (egb *EventGroupBy) Aggregate(fns ...AggregateFunc) *EventGroupBy { return egb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (egb *EventGroupBy) Scan(ctx context.Context, v any) error { - query, err := egb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, egb.build.ctx, "GroupBy") + if err := egb.build.prepareQuery(ctx); err != nil { return err } - egb.sql = query - return egb.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventGroupBy](ctx, egb.build, egb, egb.build.inters, v) } -func (egb *EventGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range egb.fields { - if !event.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (egb *EventGroupBy) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(egb.fns)) + for _, fn := range egb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*egb.flds)+len(egb.fns)) + for _, f := range *egb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := egb.sqlQuery() + selector.GroupBy(selector.Columns(*egb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := egb.driver.Query(ctx, query, args, rows); err != nil { + if err := egb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (egb *EventGroupBy) sqlQuery() *sql.Selector { - selector := egb.sql.Select() - aggregation := make([]string, 0, len(egb.fns)) - for _, fn := range egb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(egb.fields)+len(egb.fns)) - for _, f := range egb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(egb.fields...)...) -} - // EventSelect is the builder for selecting fields of Event entities. type EventSelect struct { *EventQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (es *EventSelect) Aggregate(fns ...AggregateFunc) *EventSelect { + es.fns = append(es.fns, fns...) + return es } // Scan applies the selector query and scans the result into the given value. func (es *EventSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, es.ctx, "Select") if err := es.prepareQuery(ctx); err != nil { return err } - es.sql = es.EventQuery.sqlQuery(ctx) - return es.sqlScan(ctx, v) + return scanWithInterceptors[*EventQuery, *EventSelect](ctx, es.EventQuery, es, es.inters, v) } -func (es *EventSelect) sqlScan(ctx context.Context, v any) error { +func (es *EventSelect) sqlScan(ctx context.Context, root *EventQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(es.fns)) + for _, fn := range es.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*es.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := es.sql.Query() + query, args := selector.Query() if err := es.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/event_update.go b/pkg/database/ent/event_update.go index fcd0cc50c99..c2f5c6cddb1 100644 --- a/pkg/database/ent/event_update.go +++ b/pkg/database/ent/event_update.go @@ -29,42 +29,12 @@ func (eu *EventUpdate) Where(ps ...predicate.Event) *EventUpdate { return eu } -// SetCreatedAt sets the "created_at" field. -func (eu *EventUpdate) SetCreatedAt(t time.Time) *EventUpdate { - eu.mutation.SetCreatedAt(t) - return eu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (eu *EventUpdate) ClearCreatedAt() *EventUpdate { - eu.mutation.ClearCreatedAt() - return eu -} - // SetUpdatedAt sets the "updated_at" field. func (eu *EventUpdate) SetUpdatedAt(t time.Time) *EventUpdate { eu.mutation.SetUpdatedAt(t) return eu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (eu *EventUpdate) ClearUpdatedAt() *EventUpdate { - eu.mutation.ClearUpdatedAt() - return eu -} - -// SetTime sets the "time" field. -func (eu *EventUpdate) SetTime(t time.Time) *EventUpdate { - eu.mutation.SetTime(t) - return eu -} - -// SetSerialized sets the "serialized" field. -func (eu *EventUpdate) SetSerialized(s string) *EventUpdate { - eu.mutation.SetSerialized(s) - return eu -} - // SetAlertEvents sets the "alert_events" field. func (eu *EventUpdate) SetAlertEvents(i int) *EventUpdate { eu.mutation.SetAlertEvents(i) @@ -117,41 +87,8 @@ func (eu *EventUpdate) ClearOwner() *EventUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (eu *EventUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) eu.defaults() - if len(eu.hooks) == 0 { - if err = eu.check(); err != nil { - return 0, err - } - affected, err = eu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = eu.check(); err != nil { - return 0, err - } - eu.mutation = mutation - affected, err = eu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(eu.hooks) - 1; i >= 0; i-- { - if eu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = eu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, eu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, eu.sqlSave, eu.mutation, eu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -178,37 +115,14 @@ func (eu *EventUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (eu *EventUpdate) defaults() { - if _, ok := eu.mutation.CreatedAt(); !ok && !eu.mutation.CreatedAtCleared() { - v := event.UpdateDefaultCreatedAt() - eu.mutation.SetCreatedAt(v) - } - if _, ok := eu.mutation.UpdatedAt(); !ok && !eu.mutation.UpdatedAtCleared() { + if _, ok := eu.mutation.UpdatedAt(); !ok { v := event.UpdateDefaultUpdatedAt() eu.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (eu *EventUpdate) check() error { - if v, ok := eu.mutation.Serialized(); ok { - if err := event.SerializedValidator(v); err != nil { - return &ValidationError{Name: "serialized", err: fmt.Errorf(`ent: validator failed for field "Event.serialized": %w`, err)} - } - } - return nil -} - func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) if ps := eu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -216,45 +130,8 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := eu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - } - if eu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) - } if value, ok := eu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - } - if eu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) - } - if value, ok := eu.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) - } - if value, ok := eu.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if eu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +141,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +154,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +170,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + eu.mutation.done = true return n, nil } @@ -310,42 +182,12 @@ type EventUpdateOne struct { mutation *EventMutation } -// SetCreatedAt sets the "created_at" field. -func (euo *EventUpdateOne) SetCreatedAt(t time.Time) *EventUpdateOne { - euo.mutation.SetCreatedAt(t) - return euo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (euo *EventUpdateOne) ClearCreatedAt() *EventUpdateOne { - euo.mutation.ClearCreatedAt() - return euo -} - // SetUpdatedAt sets the "updated_at" field. func (euo *EventUpdateOne) SetUpdatedAt(t time.Time) *EventUpdateOne { euo.mutation.SetUpdatedAt(t) return euo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (euo *EventUpdateOne) ClearUpdatedAt() *EventUpdateOne { - euo.mutation.ClearUpdatedAt() - return euo -} - -// SetTime sets the "time" field. -func (euo *EventUpdateOne) SetTime(t time.Time) *EventUpdateOne { - euo.mutation.SetTime(t) - return euo -} - -// SetSerialized sets the "serialized" field. -func (euo *EventUpdateOne) SetSerialized(s string) *EventUpdateOne { - euo.mutation.SetSerialized(s) - return euo -} - // SetAlertEvents sets the "alert_events" field. func (euo *EventUpdateOne) SetAlertEvents(i int) *EventUpdateOne { euo.mutation.SetAlertEvents(i) @@ -396,6 +238,12 @@ func (euo *EventUpdateOne) ClearOwner() *EventUpdateOne { return euo } +// Where appends a list predicates to the EventUpdate builder. +func (euo *EventUpdateOne) Where(ps ...predicate.Event) *EventUpdateOne { + euo.mutation.Where(ps...) + return euo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOne { @@ -405,47 +253,8 @@ func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOn // Save executes the query and returns the updated Event entity. func (euo *EventUpdateOne) Save(ctx context.Context) (*Event, error) { - var ( - err error - node *Event - ) euo.defaults() - if len(euo.hooks) == 0 { - if err = euo.check(); err != nil { - return nil, err - } - node, err = euo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = euo.check(); err != nil { - return nil, err - } - euo.mutation = mutation - node, err = euo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(euo.hooks) - 1; i >= 0; i-- { - if euo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = euo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, euo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Event) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, euo.sqlSave, euo.mutation, euo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -472,37 +281,14 @@ func (euo *EventUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (euo *EventUpdateOne) defaults() { - if _, ok := euo.mutation.CreatedAt(); !ok && !euo.mutation.CreatedAtCleared() { - v := event.UpdateDefaultCreatedAt() - euo.mutation.SetCreatedAt(v) - } - if _, ok := euo.mutation.UpdatedAt(); !ok && !euo.mutation.UpdatedAtCleared() { + if _, ok := euo.mutation.UpdatedAt(); !ok { v := event.UpdateDefaultUpdatedAt() euo.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (euo *EventUpdateOne) check() error { - if v, ok := euo.mutation.Serialized(); ok { - if err := event.SerializedValidator(v); err != nil { - return &ValidationError{Name: "serialized", err: fmt.Errorf(`ent: validator failed for field "Event.serialized": %w`, err)} - } - } - return nil -} - func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: event.Table, - Columns: event.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: event.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt)) id, ok := euo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Event.id" for update`)} @@ -527,45 +313,8 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } } } - if value, ok := euo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldCreatedAt, - }) - } - if euo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldCreatedAt, - }) - } if value, ok := euo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldUpdatedAt, - }) - } - if euo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: event.FieldUpdatedAt, - }) - } - if value, ok := euo.mutation.Time(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: event.FieldTime, - }) - } - if value, ok := euo.mutation.Serialized(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: event.FieldSerialized, - }) + _spec.SetField(event.FieldUpdatedAt, field.TypeTime, value) } if euo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +324,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +337,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error Columns: []string{event.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +356,6 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error } return nil, err } + euo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/generate.go b/pkg/database/ent/generate.go index 9f3a916c7a4..8ada999d7ab 100644 --- a/pkg/database/ent/generate.go +++ b/pkg/database/ent/generate.go @@ -1,4 +1,4 @@ package ent -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent@v0.13.1 generate ./schema diff --git a/pkg/database/ent/helpers.go b/pkg/database/ent/helpers.go new file mode 100644 index 00000000000..9b30ce451e0 --- /dev/null +++ b/pkg/database/ent/helpers.go @@ -0,0 +1,25 @@ +package ent + +func (m *Machine) GetOsname() string { + return m.Osname +} + +func (b *Bouncer) GetOsname() string { + return b.Osname +} + +func (m *Machine) GetOsversion() string { + return m.Osversion +} + +func (b *Bouncer) GetOsversion() string { + return b.Osversion +} + +func (m *Machine) GetFeatureflags() string { + return m.Featureflags +} + +func (b *Bouncer) GetFeatureflags() string { + return b.Featureflags +} diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go index 85ab00b01fb..62cc07820d0 100644 --- a/pkg/database/ent/hook/hook.go +++ b/pkg/database/ent/hook/hook.go @@ -15,11 +15,10 @@ type AlertFunc func(context.Context, *ent.AlertMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.AlertMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) + if mv, ok := m.(*ent.AlertMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) } // The BouncerFunc type is an adapter to allow the use of ordinary @@ -28,11 +27,10 @@ type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f BouncerFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.BouncerMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) + if mv, ok := m.(*ent.BouncerMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m) } // The ConfigItemFunc type is an adapter to allow the use of ordinary @@ -41,11 +39,10 @@ type ConfigItemFunc func(context.Context, *ent.ConfigItemMutation) (ent.Value, e // Mutate calls f(ctx, m). func (f ConfigItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.ConfigItemMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) + if mv, ok := m.(*ent.ConfigItemMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m) } // The DecisionFunc type is an adapter to allow the use of ordinary @@ -54,11 +51,10 @@ type DecisionFunc func(context.Context, *ent.DecisionMutation) (ent.Value, error // Mutate calls f(ctx, m). func (f DecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DecisionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) + if mv, ok := m.(*ent.DecisionMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m) } // The EventFunc type is an adapter to allow the use of ordinary @@ -67,11 +63,22 @@ type EventFunc func(context.Context, *ent.EventMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f EventFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.EventMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) + if mv, ok := m.(*ent.EventMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m) +} + +// The LockFunc type is an adapter to allow the use of ordinary +// function as Lock mutator. +type LockFunc func(context.Context, *ent.LockMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f LockFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.LockMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.LockMutation", m) } // The MachineFunc type is an adapter to allow the use of ordinary @@ -80,11 +87,10 @@ type MachineFunc func(context.Context, *ent.MachineMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f MachineFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) + if mv, ok := m.(*ent.MachineMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m) } // The MetaFunc type is an adapter to allow the use of ordinary @@ -93,11 +99,22 @@ type MetaFunc func(context.Context, *ent.MetaMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f MetaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) + if mv, ok := m.(*ent.MetaMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m) +} + +// The MetricFunc type is an adapter to allow the use of ordinary +// function as Metric mutator. +type MetricFunc func(context.Context, *ent.MetricMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f MetricFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.MetricMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetricMutation", m) } // Condition is a hook condition function. diff --git a/pkg/database/ent/lock.go b/pkg/database/ent/lock.go new file mode 100644 index 00000000000..85556a30644 --- /dev/null +++ b/pkg/database/ent/lock.go @@ -0,0 +1,117 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// Lock is the model entity for the Lock schema. +type Lock struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Lock) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case lock.FieldID: + values[i] = new(sql.NullInt64) + case lock.FieldName: + values[i] = new(sql.NullString) + case lock.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Lock fields. +func (l *Lock) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case lock.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + l.ID = int(value.Int64) + case lock.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + l.Name = value.String + } + case lock.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + l.CreatedAt = value.Time + } + default: + l.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Lock. +// This includes values selected through modifiers, order, etc. +func (l *Lock) Value(name string) (ent.Value, error) { + return l.selectValues.Get(name) +} + +// Update returns a builder for updating this Lock. +// Note that you need to call Lock.Unwrap() before calling this method if this Lock +// was returned from a transaction, and the transaction was committed or rolled back. +func (l *Lock) Update() *LockUpdateOne { + return NewLockClient(l.config).UpdateOne(l) +} + +// Unwrap unwraps the Lock entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (l *Lock) Unwrap() *Lock { + _tx, ok := l.config.driver.(*txDriver) + if !ok { + panic("ent: Lock is not a transactional entity") + } + l.config.driver = _tx.drv + return l +} + +// String implements the fmt.Stringer. +func (l *Lock) String() string { + var builder strings.Builder + builder.WriteString("Lock(") + builder.WriteString(fmt.Sprintf("id=%v, ", l.ID)) + builder.WriteString("name=") + builder.WriteString(l.Name) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(l.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// Locks is a parsable slice of Lock. +type Locks []*Lock diff --git a/pkg/database/ent/lock/lock.go b/pkg/database/ent/lock/lock.go new file mode 100644 index 00000000000..d0143470a75 --- /dev/null +++ b/pkg/database/ent/lock/lock.go @@ -0,0 +1,62 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the lock type in the database. + Label = "lock" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // Table holds the table name of the lock in the database. + Table = "locks" +) + +// Columns holds all SQL columns for lock fields. +var Columns = []string{ + FieldID, + FieldName, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the Lock queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} diff --git a/pkg/database/ent/lock/where.go b/pkg/database/ent/lock/where.go new file mode 100644 index 00000000000..cf59362d203 --- /dev/null +++ b/pkg/database/ent/lock/where.go @@ -0,0 +1,185 @@ +// Code generated by ent, DO NOT EDIT. + +package lock + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldID, id)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Lock { + return predicate.Lock(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Lock { + return predicate.Lock(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Lock { + return predicate.Lock(sql.FieldContainsFold(FieldName, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Lock { + return predicate.Lock(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Lock { + return predicate.Lock(sql.FieldLTE(FieldCreatedAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Lock) predicate.Lock { + return predicate.Lock(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Lock) predicate.Lock { + return predicate.Lock(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/lock_create.go b/pkg/database/ent/lock_create.go new file mode 100644 index 00000000000..e2c29c88324 --- /dev/null +++ b/pkg/database/ent/lock_create.go @@ -0,0 +1,215 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" +) + +// LockCreate is the builder for creating a Lock entity. +type LockCreate struct { + config + mutation *LockMutation + hooks []Hook +} + +// SetName sets the "name" field. +func (lc *LockCreate) SetName(s string) *LockCreate { + lc.mutation.SetName(s) + return lc +} + +// SetCreatedAt sets the "created_at" field. +func (lc *LockCreate) SetCreatedAt(t time.Time) *LockCreate { + lc.mutation.SetCreatedAt(t) + return lc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (lc *LockCreate) SetNillableCreatedAt(t *time.Time) *LockCreate { + if t != nil { + lc.SetCreatedAt(*t) + } + return lc +} + +// Mutation returns the LockMutation object of the builder. +func (lc *LockCreate) Mutation() *LockMutation { + return lc.mutation +} + +// Save creates the Lock in the database. +func (lc *LockCreate) Save(ctx context.Context) (*Lock, error) { + lc.defaults() + return withHooks(ctx, lc.sqlSave, lc.mutation, lc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (lc *LockCreate) SaveX(ctx context.Context) *Lock { + v, err := lc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (lc *LockCreate) Exec(ctx context.Context) error { + _, err := lc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lc *LockCreate) ExecX(ctx context.Context) { + if err := lc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (lc *LockCreate) defaults() { + if _, ok := lc.mutation.CreatedAt(); !ok { + v := lock.DefaultCreatedAt() + lc.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (lc *LockCreate) check() error { + if _, ok := lc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Lock.name"`)} + } + if _, ok := lc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Lock.created_at"`)} + } + return nil +} + +func (lc *LockCreate) sqlSave(ctx context.Context) (*Lock, error) { + if err := lc.check(); err != nil { + return nil, err + } + _node, _spec := lc.createSpec() + if err := sqlgraph.CreateNode(ctx, lc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + lc.mutation.id = &_node.ID + lc.mutation.done = true + return _node, nil +} + +func (lc *LockCreate) createSpec() (*Lock, *sqlgraph.CreateSpec) { + var ( + _node = &Lock{config: lc.config} + _spec = sqlgraph.NewCreateSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + ) + if value, ok := lc.mutation.Name(); ok { + _spec.SetField(lock.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := lc.mutation.CreatedAt(); ok { + _spec.SetField(lock.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + return _node, _spec +} + +// LockCreateBulk is the builder for creating many Lock entities in bulk. +type LockCreateBulk struct { + config + err error + builders []*LockCreate +} + +// Save creates the Lock entities in the database. +func (lcb *LockCreateBulk) Save(ctx context.Context) ([]*Lock, error) { + if lcb.err != nil { + return nil, lcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(lcb.builders)) + nodes := make([]*Lock, len(lcb.builders)) + mutators := make([]Mutator, len(lcb.builders)) + for i := range lcb.builders { + func(i int, root context.Context) { + builder := lcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*LockMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, lcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, lcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, lcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (lcb *LockCreateBulk) SaveX(ctx context.Context) []*Lock { + v, err := lcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (lcb *LockCreateBulk) Exec(ctx context.Context) error { + _, err := lcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lcb *LockCreateBulk) ExecX(ctx context.Context) { + if err := lcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_delete.go b/pkg/database/ent/lock_delete.go new file mode 100644 index 00000000000..2275c608f75 --- /dev/null +++ b/pkg/database/ent/lock_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockDelete is the builder for deleting a Lock entity. +type LockDelete struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockDelete builder. +func (ld *LockDelete) Where(ps ...predicate.Lock) *LockDelete { + ld.mutation.Where(ps...) + return ld +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ld *LockDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ld.sqlExec, ld.mutation, ld.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ld *LockDelete) ExecX(ctx context.Context) int { + n, err := ld.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ld *LockDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(lock.Table, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := ld.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ld.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ld.mutation.done = true + return affected, err +} + +// LockDeleteOne is the builder for deleting a single Lock entity. +type LockDeleteOne struct { + ld *LockDelete +} + +// Where appends a list predicates to the LockDelete builder. +func (ldo *LockDeleteOne) Where(ps ...predicate.Lock) *LockDeleteOne { + ldo.ld.mutation.Where(ps...) + return ldo +} + +// Exec executes the deletion query. +func (ldo *LockDeleteOne) Exec(ctx context.Context) error { + n, err := ldo.ld.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{lock.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (ldo *LockDeleteOne) ExecX(ctx context.Context) { + if err := ldo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/lock_query.go b/pkg/database/ent/lock_query.go new file mode 100644 index 00000000000..75e5da48a94 --- /dev/null +++ b/pkg/database/ent/lock_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockQuery is the builder for querying Lock entities. +type LockQuery struct { + config + ctx *QueryContext + order []lock.OrderOption + inters []Interceptor + predicates []predicate.Lock + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the LockQuery builder. +func (lq *LockQuery) Where(ps ...predicate.Lock) *LockQuery { + lq.predicates = append(lq.predicates, ps...) + return lq +} + +// Limit the number of records to be returned by this query. +func (lq *LockQuery) Limit(limit int) *LockQuery { + lq.ctx.Limit = &limit + return lq +} + +// Offset to start from. +func (lq *LockQuery) Offset(offset int) *LockQuery { + lq.ctx.Offset = &offset + return lq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (lq *LockQuery) Unique(unique bool) *LockQuery { + lq.ctx.Unique = &unique + return lq +} + +// Order specifies how the records should be ordered. +func (lq *LockQuery) Order(o ...lock.OrderOption) *LockQuery { + lq.order = append(lq.order, o...) + return lq +} + +// First returns the first Lock entity from the query. +// Returns a *NotFoundError when no Lock was found. +func (lq *LockQuery) First(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(1).All(setContextOp(ctx, lq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{lock.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (lq *LockQuery) FirstX(ctx context.Context) *Lock { + node, err := lq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Lock ID from the query. +// Returns a *NotFoundError when no Lock ID was found. +func (lq *LockQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(1).IDs(setContextOp(ctx, lq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{lock.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (lq *LockQuery) FirstIDX(ctx context.Context) int { + id, err := lq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Lock entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Lock entity is found. +// Returns a *NotFoundError when no Lock entities are found. +func (lq *LockQuery) Only(ctx context.Context) (*Lock, error) { + nodes, err := lq.Limit(2).All(setContextOp(ctx, lq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{lock.Label} + default: + return nil, &NotSingularError{lock.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (lq *LockQuery) OnlyX(ctx context.Context) *Lock { + node, err := lq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Lock ID in the query. +// Returns a *NotSingularError when more than one Lock ID is found. +// Returns a *NotFoundError when no entities are found. +func (lq *LockQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = lq.Limit(2).IDs(setContextOp(ctx, lq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{lock.Label} + default: + err = &NotSingularError{lock.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (lq *LockQuery) OnlyIDX(ctx context.Context) int { + id, err := lq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Locks. +func (lq *LockQuery) All(ctx context.Context) ([]*Lock, error) { + ctx = setContextOp(ctx, lq.ctx, "All") + if err := lq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Lock, *LockQuery]() + return withInterceptors[[]*Lock](ctx, lq, qr, lq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (lq *LockQuery) AllX(ctx context.Context) []*Lock { + nodes, err := lq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Lock IDs. +func (lq *LockQuery) IDs(ctx context.Context) (ids []int, err error) { + if lq.ctx.Unique == nil && lq.path != nil { + lq.Unique(true) + } + ctx = setContextOp(ctx, lq.ctx, "IDs") + if err = lq.Select(lock.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (lq *LockQuery) IDsX(ctx context.Context) []int { + ids, err := lq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (lq *LockQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, lq.ctx, "Count") + if err := lq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, lq, querierCount[*LockQuery](), lq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (lq *LockQuery) CountX(ctx context.Context) int { + count, err := lq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (lq *LockQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, lq.ctx, "Exist") + switch _, err := lq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (lq *LockQuery) ExistX(ctx context.Context) bool { + exist, err := lq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the LockQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (lq *LockQuery) Clone() *LockQuery { + if lq == nil { + return nil + } + return &LockQuery{ + config: lq.config, + ctx: lq.ctx.Clone(), + order: append([]lock.OrderOption{}, lq.order...), + inters: append([]Interceptor{}, lq.inters...), + predicates: append([]predicate.Lock{}, lq.predicates...), + // clone intermediate query. + sql: lq.sql.Clone(), + path: lq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// Count int `json:"count,omitempty"` +// } +// +// client.Lock.Query(). +// GroupBy(lock.FieldName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (lq *LockQuery) GroupBy(field string, fields ...string) *LockGroupBy { + lq.ctx.Fields = append([]string{field}, fields...) + grbuild := &LockGroupBy{build: lq} + grbuild.flds = &lq.ctx.Fields + grbuild.label = lock.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Name string `json:"name"` +// } +// +// client.Lock.Query(). +// Select(lock.FieldName). +// Scan(ctx, &v) +func (lq *LockQuery) Select(fields ...string) *LockSelect { + lq.ctx.Fields = append(lq.ctx.Fields, fields...) + sbuild := &LockSelect{LockQuery: lq} + sbuild.label = lock.Label + sbuild.flds, sbuild.scan = &lq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a LockSelect configured with the given aggregations. +func (lq *LockQuery) Aggregate(fns ...AggregateFunc) *LockSelect { + return lq.Select().Aggregate(fns...) +} + +func (lq *LockQuery) prepareQuery(ctx context.Context) error { + for _, inter := range lq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, lq); err != nil { + return err + } + } + } + for _, f := range lq.ctx.Fields { + if !lock.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if lq.path != nil { + prev, err := lq.path(ctx) + if err != nil { + return err + } + lq.sql = prev + } + return nil +} + +func (lq *LockQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Lock, error) { + var ( + nodes = []*Lock{} + _spec = lq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Lock).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Lock{config: lq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, lq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (lq *LockQuery) sqlCount(ctx context.Context) (int, error) { + _spec := lq.querySpec() + _spec.Node.Columns = lq.ctx.Fields + if len(lq.ctx.Fields) > 0 { + _spec.Unique = lq.ctx.Unique != nil && *lq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, lq.driver, _spec) +} + +func (lq *LockQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + _spec.From = lq.sql + if unique := lq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if lq.path != nil { + _spec.Unique = true + } + if fields := lq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for i := range fields { + if fields[i] != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := lq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := lq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := lq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := lq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (lq *LockQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(lq.driver.Dialect()) + t1 := builder.Table(lock.Table) + columns := lq.ctx.Fields + if len(columns) == 0 { + columns = lock.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if lq.sql != nil { + selector = lq.sql + selector.Select(selector.Columns(columns...)...) + } + if lq.ctx.Unique != nil && *lq.ctx.Unique { + selector.Distinct() + } + for _, p := range lq.predicates { + p(selector) + } + for _, p := range lq.order { + p(selector) + } + if offset := lq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := lq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// LockGroupBy is the group-by builder for Lock entities. +type LockGroupBy struct { + selector + build *LockQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (lgb *LockGroupBy) Aggregate(fns ...AggregateFunc) *LockGroupBy { + lgb.fns = append(lgb.fns, fns...) + return lgb +} + +// Scan applies the selector query and scans the result into the given value. +func (lgb *LockGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, lgb.build.ctx, "GroupBy") + if err := lgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockGroupBy](ctx, lgb.build, lgb, lgb.build.inters, v) +} + +func (lgb *LockGroupBy) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(lgb.fns)) + for _, fn := range lgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*lgb.flds)+len(lgb.fns)) + for _, f := range *lgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*lgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := lgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// LockSelect is the builder for selecting fields of Lock entities. +type LockSelect struct { + *LockQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ls *LockSelect) Aggregate(fns ...AggregateFunc) *LockSelect { + ls.fns = append(ls.fns, fns...) + return ls +} + +// Scan applies the selector query and scans the result into the given value. +func (ls *LockSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ls.ctx, "Select") + if err := ls.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*LockQuery, *LockSelect](ctx, ls.LockQuery, ls, ls.inters, v) +} + +func (ls *LockSelect) sqlScan(ctx context.Context, root *LockQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ls.fns)) + for _, fn := range ls.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ls.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ls.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/lock_update.go b/pkg/database/ent/lock_update.go new file mode 100644 index 00000000000..934e68c0762 --- /dev/null +++ b/pkg/database/ent/lock_update.go @@ -0,0 +1,175 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// LockUpdate is the builder for updating Lock entities. +type LockUpdate struct { + config + hooks []Hook + mutation *LockMutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (lu *LockUpdate) Where(ps ...predicate.Lock) *LockUpdate { + lu.mutation.Where(ps...) + return lu +} + +// Mutation returns the LockMutation object of the builder. +func (lu *LockUpdate) Mutation() *LockMutation { + return lu.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (lu *LockUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, lu.sqlSave, lu.mutation, lu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (lu *LockUpdate) SaveX(ctx context.Context) int { + affected, err := lu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (lu *LockUpdate) Exec(ctx context.Context) error { + _, err := lu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (lu *LockUpdate) ExecX(ctx context.Context) { + if err := lu.Exec(ctx); err != nil { + panic(err) + } +} + +func (lu *LockUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + if ps := lu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if n, err = sqlgraph.UpdateNodes(ctx, lu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + lu.mutation.done = true + return n, nil +} + +// LockUpdateOne is the builder for updating a single Lock entity. +type LockUpdateOne struct { + config + fields []string + hooks []Hook + mutation *LockMutation +} + +// Mutation returns the LockMutation object of the builder. +func (luo *LockUpdateOne) Mutation() *LockMutation { + return luo.mutation +} + +// Where appends a list predicates to the LockUpdate builder. +func (luo *LockUpdateOne) Where(ps ...predicate.Lock) *LockUpdateOne { + luo.mutation.Where(ps...) + return luo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (luo *LockUpdateOne) Select(field string, fields ...string) *LockUpdateOne { + luo.fields = append([]string{field}, fields...) + return luo +} + +// Save executes the query and returns the updated Lock entity. +func (luo *LockUpdateOne) Save(ctx context.Context) (*Lock, error) { + return withHooks(ctx, luo.sqlSave, luo.mutation, luo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (luo *LockUpdateOne) SaveX(ctx context.Context) *Lock { + node, err := luo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (luo *LockUpdateOne) Exec(ctx context.Context) error { + _, err := luo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (luo *LockUpdateOne) ExecX(ctx context.Context) { + if err := luo.Exec(ctx); err != nil { + panic(err) + } +} + +func (luo *LockUpdateOne) sqlSave(ctx context.Context) (_node *Lock, err error) { + _spec := sqlgraph.NewUpdateSpec(lock.Table, lock.Columns, sqlgraph.NewFieldSpec(lock.FieldID, field.TypeInt)) + id, ok := luo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Lock.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := luo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, lock.FieldID) + for _, f := range fields { + if !lock.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != lock.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := luo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + _node = &Lock{config: luo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, luo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{lock.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + luo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/machine.go b/pkg/database/ent/machine.go index dc2b18ee81c..76127065791 100644 --- a/pkg/database/ent/machine.go +++ b/pkg/database/ent/machine.go @@ -3,12 +3,15 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // Machine is the model entity for the Machine schema. @@ -17,9 +20,9 @@ type Machine struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // LastPush holds the value of the "last_push" field. LastPush *time.Time `json:"last_push,omitempty"` // LastHeartbeat holds the value of the "last_heartbeat" field. @@ -36,13 +39,22 @@ type Machine struct { Version string `json:"version,omitempty"` // IsValidated holds the value of the "isValidated" field. IsValidated bool `json:"isValidated,omitempty"` - // Status holds the value of the "status" field. - Status string `json:"status,omitempty"` // AuthType holds the value of the "auth_type" field. AuthType string `json:"auth_type"` + // Osname holds the value of the "osname" field. + Osname string `json:"osname,omitempty"` + // Osversion holds the value of the "osversion" field. + Osversion string `json:"osversion,omitempty"` + // Featureflags holds the value of the "featureflags" field. + Featureflags string `json:"featureflags,omitempty"` + // Hubstate holds the value of the "hubstate" field. + Hubstate map[string][]schema.ItemState `json:"hubstate,omitempty"` + // Datasources holds the value of the "datasources" field. + Datasources map[string]int64 `json:"datasources,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MachineQuery when eager-loading is set. - Edges MachineEdges `json:"edges"` + Edges MachineEdges `json:"edges"` + selectValues sql.SelectValues } // MachineEdges holds the relations/edges for other nodes in the graph. @@ -68,16 +80,18 @@ func (*Machine) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case machine.FieldHubstate, machine.FieldDatasources: + values[i] = new([]byte) case machine.FieldIsValidated: values[i] = new(sql.NullBool) case machine.FieldID: values[i] = new(sql.NullInt64) - case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus, machine.FieldAuthType: + case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldAuthType, machine.FieldOsname, machine.FieldOsversion, machine.FieldFeatureflags: values[i] = new(sql.NullString) case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Machine", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -101,15 +115,13 @@ func (m *Machine) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - m.CreatedAt = new(time.Time) - *m.CreatedAt = value.Time + m.CreatedAt = value.Time } case machine.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - m.UpdatedAt = new(time.Time) - *m.UpdatedAt = value.Time + m.UpdatedAt = value.Time } case machine.FieldLastPush: if value, ok := values[i].(*sql.NullTime); !ok { @@ -161,33 +173,69 @@ func (m *Machine) assignValues(columns []string, values []any) error { } else if value.Valid { m.IsValidated = value.Bool } - case machine.FieldStatus: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field status", values[i]) - } else if value.Valid { - m.Status = value.String - } case machine.FieldAuthType: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field auth_type", values[i]) } else if value.Valid { m.AuthType = value.String } + case machine.FieldOsname: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osname", values[i]) + } else if value.Valid { + m.Osname = value.String + } + case machine.FieldOsversion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field osversion", values[i]) + } else if value.Valid { + m.Osversion = value.String + } + case machine.FieldFeatureflags: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field featureflags", values[i]) + } else if value.Valid { + m.Featureflags = value.String + } + case machine.FieldHubstate: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field hubstate", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &m.Hubstate); err != nil { + return fmt.Errorf("unmarshal field hubstate: %w", err) + } + } + case machine.FieldDatasources: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field datasources", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &m.Datasources); err != nil { + return fmt.Errorf("unmarshal field datasources: %w", err) + } + } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Machine. +// This includes values selected through modifiers, order, etc. +func (m *Machine) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryAlerts queries the "alerts" edge of the Machine entity. func (m *Machine) QueryAlerts() *AlertQuery { - return (&MachineClient{config: m.config}).QueryAlerts(m) + return NewMachineClient(m.config).QueryAlerts(m) } // Update returns a builder for updating this Machine. // Note that you need to call Machine.Unwrap() before calling this method if this Machine // was returned from a transaction, and the transaction was committed or rolled back. func (m *Machine) Update() *MachineUpdateOne { - return (&MachineClient{config: m.config}).UpdateOne(m) + return NewMachineClient(m.config).UpdateOne(m) } // Unwrap unwraps the Machine entity that was returned from a transaction after it was closed, @@ -206,15 +254,11 @@ func (m *Machine) String() string { var builder strings.Builder builder.WriteString("Machine(") builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) - if v := m.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := m.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") if v := m.LastPush; v != nil { builder.WriteString("last_push=") @@ -243,20 +287,26 @@ func (m *Machine) String() string { builder.WriteString("isValidated=") builder.WriteString(fmt.Sprintf("%v", m.IsValidated)) builder.WriteString(", ") - builder.WriteString("status=") - builder.WriteString(m.Status) - builder.WriteString(", ") builder.WriteString("auth_type=") builder.WriteString(m.AuthType) + builder.WriteString(", ") + builder.WriteString("osname=") + builder.WriteString(m.Osname) + builder.WriteString(", ") + builder.WriteString("osversion=") + builder.WriteString(m.Osversion) + builder.WriteString(", ") + builder.WriteString("featureflags=") + builder.WriteString(m.Featureflags) + builder.WriteString(", ") + builder.WriteString("hubstate=") + builder.WriteString(fmt.Sprintf("%v", m.Hubstate)) + builder.WriteString(", ") + builder.WriteString("datasources=") + builder.WriteString(fmt.Sprintf("%v", m.Datasources)) builder.WriteByte(')') return builder.String() } // Machines is a parsable slice of Machine. type Machines []*Machine - -func (m Machines) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/machine/machine.go b/pkg/database/ent/machine/machine.go index e6900dd21e1..009e6e19c35 100644 --- a/pkg/database/ent/machine/machine.go +++ b/pkg/database/ent/machine/machine.go @@ -4,6 +4,9 @@ package machine import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -31,10 +34,18 @@ const ( FieldVersion = "version" // FieldIsValidated holds the string denoting the isvalidated field in the database. FieldIsValidated = "is_validated" - // FieldStatus holds the string denoting the status field in the database. - FieldStatus = "status" // FieldAuthType holds the string denoting the auth_type field in the database. FieldAuthType = "auth_type" + // FieldOsname holds the string denoting the osname field in the database. + FieldOsname = "osname" + // FieldOsversion holds the string denoting the osversion field in the database. + FieldOsversion = "osversion" + // FieldFeatureflags holds the string denoting the featureflags field in the database. + FieldFeatureflags = "featureflags" + // FieldHubstate holds the string denoting the hubstate field in the database. + FieldHubstate = "hubstate" + // FieldDatasources holds the string denoting the datasources field in the database. + FieldDatasources = "datasources" // EdgeAlerts holds the string denoting the alerts edge name in mutations. EdgeAlerts = "alerts" // Table holds the table name of the machine in the database. @@ -61,8 +72,12 @@ var Columns = []string{ FieldScenarios, FieldVersion, FieldIsValidated, - FieldStatus, FieldAuthType, + FieldOsname, + FieldOsversion, + FieldFeatureflags, + FieldHubstate, + FieldDatasources, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -78,20 +93,12 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. UpdateDefaultUpdatedAt func() time.Time // DefaultLastPush holds the default value on creation for the "last_push" field. DefaultLastPush func() time.Time - // UpdateDefaultLastPush holds the default value on update for the "last_push" field. - UpdateDefaultLastPush func() time.Time - // DefaultLastHeartbeat holds the default value on creation for the "last_heartbeat" field. - DefaultLastHeartbeat func() time.Time - // UpdateDefaultLastHeartbeat holds the default value on update for the "last_heartbeat" field. - UpdateDefaultLastHeartbeat func() time.Time // ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. ScenariosValidator func(string) error // DefaultIsValidated holds the default value on creation for the "isValidated" field. @@ -99,3 +106,102 @@ var ( // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string ) + +// OrderOption defines the ordering options for the Machine queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByLastPush orders the results by the last_push field. +func ByLastPush(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastPush, opts...).ToFunc() +} + +// ByLastHeartbeat orders the results by the last_heartbeat field. +func ByLastHeartbeat(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastHeartbeat, opts...).ToFunc() +} + +// ByMachineId orders the results by the machineId field. +func ByMachineId(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMachineId, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByIpAddress orders the results by the ipAddress field. +func ByIpAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIpAddress, opts...).ToFunc() +} + +// ByScenarios orders the results by the scenarios field. +func ByScenarios(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScenarios, opts...).ToFunc() +} + +// ByVersion orders the results by the version field. +func ByVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVersion, opts...).ToFunc() +} + +// ByIsValidated orders the results by the isValidated field. +func ByIsValidated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsValidated, opts...).ToFunc() +} + +// ByAuthType orders the results by the auth_type field. +func ByAuthType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthType, opts...).ToFunc() +} + +// ByOsname orders the results by the osname field. +func ByOsname(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsname, opts...).ToFunc() +} + +// ByOsversion orders the results by the osversion field. +func ByOsversion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOsversion, opts...).ToFunc() +} + +// ByFeatureflags orders the results by the featureflags field. +func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() +} + +// ByAlertsCount orders the results by alerts count. +func ByAlertsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAlertsStep(), opts...) + } +} + +// ByAlerts orders the results by alerts terms. +func ByAlerts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAlertsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newAlertsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AlertsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), + ) +} diff --git a/pkg/database/ent/machine/where.go b/pkg/database/ent/machine/where.go index 7d0227731cc..de523510f33 100644 --- a/pkg/database/ent/machine/where.go +++ b/pkg/database/ent/machine/where.go @@ -12,1218 +12,962 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Machine(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // LastPush applies equality check predicate on the "last_push" field. It's identical to LastPushEQ. func LastPush(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastHeartbeat applies equality check predicate on the "last_heartbeat" field. It's identical to LastHeartbeatEQ. func LastHeartbeat(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // MachineId applies equality check predicate on the "machineId" field. It's identical to MachineIdEQ. func MachineId(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. func Password(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // IpAddress applies equality check predicate on the "ipAddress" field. It's identical to IpAddressEQ. func IpAddress(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // Scenarios applies equality check predicate on the "scenarios" field. It's identical to ScenariosEQ. func Scenarios(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // Version applies equality check predicate on the "version" field. It's identical to VersionEQ. func Version(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // IsValidated applies equality check predicate on the "isValidated" field. It's identical to IsValidatedEQ. func IsValidated(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) -} - -// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. -func Status(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. func AuthType(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) +} + +// Osname applies equality check predicate on the "osname" field. It's identical to OsnameEQ. +func Osname(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsname, v)) +} + +// Osversion applies equality check predicate on the "osversion" field. It's identical to OsversionEQ. +func Osversion(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsversion, v)) +} + +// Featureflags applies equality check predicate on the "featureflags" field. It's identical to FeatureflagsEQ. +func Featureflags(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldFeatureflags, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Machine(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Machine(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Machine(sql.FieldLTE(FieldUpdatedAt, v)) } // LastPushEQ applies the EQ predicate on the "last_push" field. func LastPushEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastPush, v)) } // LastPushNEQ applies the NEQ predicate on the "last_push" field. func LastPushNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastPush, v)) } // LastPushIn applies the In predicate on the "last_push" field. func LastPushIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastPush, vs...)) } // LastPushNotIn applies the NotIn predicate on the "last_push" field. func LastPushNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastPush), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastPush, vs...)) } // LastPushGT applies the GT predicate on the "last_push" field. func LastPushGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastPush, v)) } // LastPushGTE applies the GTE predicate on the "last_push" field. func LastPushGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastPush, v)) } // LastPushLT applies the LT predicate on the "last_push" field. func LastPushLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastPush, v)) } // LastPushLTE applies the LTE predicate on the "last_push" field. func LastPushLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastPush), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastPush, v)) } // LastPushIsNil applies the IsNil predicate on the "last_push" field. func LastPushIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastPush)) } // LastPushNotNil applies the NotNil predicate on the "last_push" field. func LastPushNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastPush))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastPush)) } // LastHeartbeatEQ applies the EQ predicate on the "last_heartbeat" field. func LastHeartbeatEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldLastHeartbeat, v)) } // LastHeartbeatNEQ applies the NEQ predicate on the "last_heartbeat" field. func LastHeartbeatNEQ(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldLastHeartbeat, v)) } // LastHeartbeatIn applies the In predicate on the "last_heartbeat" field. func LastHeartbeatIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatNotIn applies the NotIn predicate on the "last_heartbeat" field. func LastHeartbeatNotIn(vs ...time.Time) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldLastHeartbeat), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldLastHeartbeat, vs...)) } // LastHeartbeatGT applies the GT predicate on the "last_heartbeat" field. func LastHeartbeatGT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGT(FieldLastHeartbeat, v)) } // LastHeartbeatGTE applies the GTE predicate on the "last_heartbeat" field. func LastHeartbeatGTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldLastHeartbeat, v)) } // LastHeartbeatLT applies the LT predicate on the "last_heartbeat" field. func LastHeartbeatLT(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLT(FieldLastHeartbeat, v)) } // LastHeartbeatLTE applies the LTE predicate on the "last_heartbeat" field. func LastHeartbeatLTE(v time.Time) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldLastHeartbeat), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldLastHeartbeat, v)) } // LastHeartbeatIsNil applies the IsNil predicate on the "last_heartbeat" field. func LastHeartbeatIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldIsNull(FieldLastHeartbeat)) } // LastHeartbeatNotNil applies the NotNil predicate on the "last_heartbeat" field. func LastHeartbeatNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldLastHeartbeat))) - }) + return predicate.Machine(sql.FieldNotNull(FieldLastHeartbeat)) } // MachineIdEQ applies the EQ predicate on the "machineId" field. func MachineIdEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldMachineId, v)) } // MachineIdNEQ applies the NEQ predicate on the "machineId" field. func MachineIdNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldMachineId, v)) } // MachineIdIn applies the In predicate on the "machineId" field. func MachineIdIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldMachineId, vs...)) } // MachineIdNotIn applies the NotIn predicate on the "machineId" field. func MachineIdNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldMachineId), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldMachineId, vs...)) } // MachineIdGT applies the GT predicate on the "machineId" field. func MachineIdGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGT(FieldMachineId, v)) } // MachineIdGTE applies the GTE predicate on the "machineId" field. func MachineIdGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldMachineId, v)) } // MachineIdLT applies the LT predicate on the "machineId" field. func MachineIdLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLT(FieldMachineId, v)) } // MachineIdLTE applies the LTE predicate on the "machineId" field. func MachineIdLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldMachineId, v)) } // MachineIdContains applies the Contains predicate on the "machineId" field. func MachineIdContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContains(FieldMachineId, v)) } // MachineIdHasPrefix applies the HasPrefix predicate on the "machineId" field. func MachineIdHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldMachineId, v)) } // MachineIdHasSuffix applies the HasSuffix predicate on the "machineId" field. func MachineIdHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldMachineId, v)) } // MachineIdEqualFold applies the EqualFold predicate on the "machineId" field. func MachineIdEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldMachineId, v)) } // MachineIdContainsFold applies the ContainsFold predicate on the "machineId" field. func MachineIdContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldMachineId), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldMachineId, v)) } // PasswordEQ applies the EQ predicate on the "password" field. func PasswordEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldPassword, v)) } // PasswordNEQ applies the NEQ predicate on the "password" field. func PasswordNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldPassword, v)) } // PasswordIn applies the In predicate on the "password" field. func PasswordIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldPassword, vs...)) } // PasswordNotIn applies the NotIn predicate on the "password" field. func PasswordNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldPassword), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldPassword, vs...)) } // PasswordGT applies the GT predicate on the "password" field. func PasswordGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGT(FieldPassword, v)) } // PasswordGTE applies the GTE predicate on the "password" field. func PasswordGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldPassword, v)) } // PasswordLT applies the LT predicate on the "password" field. func PasswordLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLT(FieldPassword, v)) } // PasswordLTE applies the LTE predicate on the "password" field. func PasswordLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldPassword, v)) } // PasswordContains applies the Contains predicate on the "password" field. func PasswordContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContains(FieldPassword, v)) } // PasswordHasPrefix applies the HasPrefix predicate on the "password" field. func PasswordHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldPassword, v)) } // PasswordHasSuffix applies the HasSuffix predicate on the "password" field. func PasswordHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldPassword, v)) } // PasswordEqualFold applies the EqualFold predicate on the "password" field. func PasswordEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldPassword, v)) } // PasswordContainsFold applies the ContainsFold predicate on the "password" field. func PasswordContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldPassword), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldPassword, v)) } // IpAddressEQ applies the EQ predicate on the "ipAddress" field. func IpAddressEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIpAddress, v)) } // IpAddressNEQ applies the NEQ predicate on the "ipAddress" field. func IpAddressNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIpAddress, v)) } // IpAddressIn applies the In predicate on the "ipAddress" field. func IpAddressIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldIpAddress, vs...)) } // IpAddressNotIn applies the NotIn predicate on the "ipAddress" field. func IpAddressNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldIpAddress), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldIpAddress, vs...)) } // IpAddressGT applies the GT predicate on the "ipAddress" field. func IpAddressGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGT(FieldIpAddress, v)) } // IpAddressGTE applies the GTE predicate on the "ipAddress" field. func IpAddressGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldIpAddress, v)) } // IpAddressLT applies the LT predicate on the "ipAddress" field. func IpAddressLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLT(FieldIpAddress, v)) } // IpAddressLTE applies the LTE predicate on the "ipAddress" field. func IpAddressLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldIpAddress, v)) } // IpAddressContains applies the Contains predicate on the "ipAddress" field. func IpAddressContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContains(FieldIpAddress, v)) } // IpAddressHasPrefix applies the HasPrefix predicate on the "ipAddress" field. func IpAddressHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldIpAddress, v)) } // IpAddressHasSuffix applies the HasSuffix predicate on the "ipAddress" field. func IpAddressHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldIpAddress, v)) } // IpAddressEqualFold applies the EqualFold predicate on the "ipAddress" field. func IpAddressEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldIpAddress, v)) } // IpAddressContainsFold applies the ContainsFold predicate on the "ipAddress" field. func IpAddressContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldIpAddress), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldIpAddress, v)) } // ScenariosEQ applies the EQ predicate on the "scenarios" field. func ScenariosEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldScenarios, v)) } // ScenariosNEQ applies the NEQ predicate on the "scenarios" field. func ScenariosNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldScenarios, v)) } // ScenariosIn applies the In predicate on the "scenarios" field. func ScenariosIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldScenarios, vs...)) } // ScenariosNotIn applies the NotIn predicate on the "scenarios" field. func ScenariosNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldScenarios), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldScenarios, vs...)) } // ScenariosGT applies the GT predicate on the "scenarios" field. func ScenariosGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGT(FieldScenarios, v)) } // ScenariosGTE applies the GTE predicate on the "scenarios" field. func ScenariosGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldScenarios, v)) } // ScenariosLT applies the LT predicate on the "scenarios" field. func ScenariosLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLT(FieldScenarios, v)) } // ScenariosLTE applies the LTE predicate on the "scenarios" field. func ScenariosLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldScenarios, v)) } // ScenariosContains applies the Contains predicate on the "scenarios" field. func ScenariosContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContains(FieldScenarios, v)) } // ScenariosHasPrefix applies the HasPrefix predicate on the "scenarios" field. func ScenariosHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldScenarios, v)) } // ScenariosHasSuffix applies the HasSuffix predicate on the "scenarios" field. func ScenariosHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldScenarios, v)) } // ScenariosIsNil applies the IsNil predicate on the "scenarios" field. func ScenariosIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldIsNull(FieldScenarios)) } // ScenariosNotNil applies the NotNil predicate on the "scenarios" field. func ScenariosNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldScenarios))) - }) + return predicate.Machine(sql.FieldNotNull(FieldScenarios)) } // ScenariosEqualFold applies the EqualFold predicate on the "scenarios" field. func ScenariosEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldScenarios, v)) } // ScenariosContainsFold applies the ContainsFold predicate on the "scenarios" field. func ScenariosContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldScenarios), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldScenarios, v)) } // VersionEQ applies the EQ predicate on the "version" field. func VersionEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldVersion, v)) } // VersionNEQ applies the NEQ predicate on the "version" field. func VersionNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldVersion, v)) } // VersionIn applies the In predicate on the "version" field. func VersionIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldIn(FieldVersion, vs...)) } // VersionNotIn applies the NotIn predicate on the "version" field. func VersionNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldVersion), v...)) - }) + return predicate.Machine(sql.FieldNotIn(FieldVersion, vs...)) } // VersionGT applies the GT predicate on the "version" field. func VersionGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGT(FieldVersion, v)) } // VersionGTE applies the GTE predicate on the "version" field. func VersionGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldGTE(FieldVersion, v)) } // VersionLT applies the LT predicate on the "version" field. func VersionLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLT(FieldVersion, v)) } // VersionLTE applies the LTE predicate on the "version" field. func VersionLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldLTE(FieldVersion, v)) } // VersionContains applies the Contains predicate on the "version" field. func VersionContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContains(FieldVersion, v)) } // VersionHasPrefix applies the HasPrefix predicate on the "version" field. func VersionHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasPrefix(FieldVersion, v)) } // VersionHasSuffix applies the HasSuffix predicate on the "version" field. func VersionHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldHasSuffix(FieldVersion, v)) } // VersionIsNil applies the IsNil predicate on the "version" field. func VersionIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldIsNull(FieldVersion)) } // VersionNotNil applies the NotNil predicate on the "version" field. func VersionNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldVersion))) - }) + return predicate.Machine(sql.FieldNotNull(FieldVersion)) } // VersionEqualFold applies the EqualFold predicate on the "version" field. func VersionEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldEqualFold(FieldVersion, v)) } // VersionContainsFold applies the ContainsFold predicate on the "version" field. func VersionContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldVersion), v)) - }) + return predicate.Machine(sql.FieldContainsFold(FieldVersion, v)) } // IsValidatedEQ applies the EQ predicate on the "isValidated" field. func IsValidatedEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldEQ(FieldIsValidated, v)) } // IsValidatedNEQ applies the NEQ predicate on the "isValidated" field. func IsValidatedNEQ(v bool) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldIsValidated), v)) - }) + return predicate.Machine(sql.FieldNEQ(FieldIsValidated, v)) } -// StatusEQ applies the EQ predicate on the "status" field. -func StatusEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldStatus), v)) - }) +// AuthTypeEQ applies the EQ predicate on the "auth_type" field. +func AuthTypeEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldAuthType, v)) } -// StatusNEQ applies the NEQ predicate on the "status" field. -func StatusNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldStatus), v)) - }) +// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. +func AuthTypeNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldAuthType, v)) } -// StatusIn applies the In predicate on the "status" field. -func StatusIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldStatus), v...)) - }) +// AuthTypeIn applies the In predicate on the "auth_type" field. +func AuthTypeIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldAuthType, vs...)) } -// StatusNotIn applies the NotIn predicate on the "status" field. -func StatusNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldStatus), v...)) - }) +// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. +func AuthTypeNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldAuthType, vs...)) } -// StatusGT applies the GT predicate on the "status" field. -func StatusGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldStatus), v)) - }) +// AuthTypeGT applies the GT predicate on the "auth_type" field. +func AuthTypeGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldAuthType, v)) } -// StatusGTE applies the GTE predicate on the "status" field. -func StatusGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldStatus), v)) - }) +// AuthTypeGTE applies the GTE predicate on the "auth_type" field. +func AuthTypeGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldAuthType, v)) } -// StatusLT applies the LT predicate on the "status" field. -func StatusLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldStatus), v)) - }) +// AuthTypeLT applies the LT predicate on the "auth_type" field. +func AuthTypeLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldAuthType, v)) } -// StatusLTE applies the LTE predicate on the "status" field. -func StatusLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldStatus), v)) - }) +// AuthTypeLTE applies the LTE predicate on the "auth_type" field. +func AuthTypeLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldAuthType, v)) } -// StatusContains applies the Contains predicate on the "status" field. -func StatusContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldStatus), v)) - }) +// AuthTypeContains applies the Contains predicate on the "auth_type" field. +func AuthTypeContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldAuthType, v)) } -// StatusHasPrefix applies the HasPrefix predicate on the "status" field. -func StatusHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldStatus), v)) - }) +// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. +func AuthTypeHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldAuthType, v)) } -// StatusHasSuffix applies the HasSuffix predicate on the "status" field. -func StatusHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldStatus), v)) - }) +// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. +func AuthTypeHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldAuthType, v)) } -// StatusIsNil applies the IsNil predicate on the "status" field. -func StatusIsNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldStatus))) - }) +// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. +func AuthTypeEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldAuthType, v)) } -// StatusNotNil applies the NotNil predicate on the "status" field. -func StatusNotNil() predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldStatus))) - }) +// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. +func AuthTypeContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldAuthType, v)) } -// StatusEqualFold applies the EqualFold predicate on the "status" field. -func StatusEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldStatus), v)) - }) +// OsnameEQ applies the EQ predicate on the "osname" field. +func OsnameEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsname, v)) } -// StatusContainsFold applies the ContainsFold predicate on the "status" field. -func StatusContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldStatus), v)) - }) +// OsnameNEQ applies the NEQ predicate on the "osname" field. +func OsnameNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldOsname, v)) } -// AuthTypeEQ applies the EQ predicate on the "auth_type" field. -func AuthTypeEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAuthType), v)) - }) +// OsnameIn applies the In predicate on the "osname" field. +func OsnameIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldOsname, vs...)) } -// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. -func AuthTypeNEQ(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAuthType), v)) - }) +// OsnameNotIn applies the NotIn predicate on the "osname" field. +func OsnameNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldOsname, vs...)) } -// AuthTypeIn applies the In predicate on the "auth_type" field. -func AuthTypeIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAuthType), v...)) - }) +// OsnameGT applies the GT predicate on the "osname" field. +func OsnameGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldOsname, v)) } -// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. -func AuthTypeNotIn(vs ...string) predicate.Machine { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAuthType), v...)) - }) +// OsnameGTE applies the GTE predicate on the "osname" field. +func OsnameGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldOsname, v)) } -// AuthTypeGT applies the GT predicate on the "auth_type" field. -func AuthTypeGT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAuthType), v)) - }) +// OsnameLT applies the LT predicate on the "osname" field. +func OsnameLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldOsname, v)) } -// AuthTypeGTE applies the GTE predicate on the "auth_type" field. -func AuthTypeGTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAuthType), v)) - }) +// OsnameLTE applies the LTE predicate on the "osname" field. +func OsnameLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldOsname, v)) } -// AuthTypeLT applies the LT predicate on the "auth_type" field. -func AuthTypeLT(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAuthType), v)) - }) +// OsnameContains applies the Contains predicate on the "osname" field. +func OsnameContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldOsname, v)) } -// AuthTypeLTE applies the LTE predicate on the "auth_type" field. -func AuthTypeLTE(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAuthType), v)) - }) +// OsnameHasPrefix applies the HasPrefix predicate on the "osname" field. +func OsnameHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldOsname, v)) } -// AuthTypeContains applies the Contains predicate on the "auth_type" field. -func AuthTypeContains(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAuthType), v)) - }) +// OsnameHasSuffix applies the HasSuffix predicate on the "osname" field. +func OsnameHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldOsname, v)) } -// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. -func AuthTypeHasPrefix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) - }) +// OsnameIsNil applies the IsNil predicate on the "osname" field. +func OsnameIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldOsname)) } -// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. -func AuthTypeHasSuffix(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) - }) +// OsnameNotNil applies the NotNil predicate on the "osname" field. +func OsnameNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldOsname)) } -// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. -func AuthTypeEqualFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAuthType), v)) - }) +// OsnameEqualFold applies the EqualFold predicate on the "osname" field. +func OsnameEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldOsname, v)) } -// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. -func AuthTypeContainsFold(v string) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) - }) +// OsnameContainsFold applies the ContainsFold predicate on the "osname" field. +func OsnameContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldOsname, v)) +} + +// OsversionEQ applies the EQ predicate on the "osversion" field. +func OsversionEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldOsversion, v)) +} + +// OsversionNEQ applies the NEQ predicate on the "osversion" field. +func OsversionNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldOsversion, v)) +} + +// OsversionIn applies the In predicate on the "osversion" field. +func OsversionIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldOsversion, vs...)) +} + +// OsversionNotIn applies the NotIn predicate on the "osversion" field. +func OsversionNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldOsversion, vs...)) +} + +// OsversionGT applies the GT predicate on the "osversion" field. +func OsversionGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldOsversion, v)) +} + +// OsversionGTE applies the GTE predicate on the "osversion" field. +func OsversionGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldOsversion, v)) +} + +// OsversionLT applies the LT predicate on the "osversion" field. +func OsversionLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldOsversion, v)) +} + +// OsversionLTE applies the LTE predicate on the "osversion" field. +func OsversionLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldOsversion, v)) +} + +// OsversionContains applies the Contains predicate on the "osversion" field. +func OsversionContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldOsversion, v)) +} + +// OsversionHasPrefix applies the HasPrefix predicate on the "osversion" field. +func OsversionHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldOsversion, v)) +} + +// OsversionHasSuffix applies the HasSuffix predicate on the "osversion" field. +func OsversionHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldOsversion, v)) +} + +// OsversionIsNil applies the IsNil predicate on the "osversion" field. +func OsversionIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldOsversion)) +} + +// OsversionNotNil applies the NotNil predicate on the "osversion" field. +func OsversionNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldOsversion)) +} + +// OsversionEqualFold applies the EqualFold predicate on the "osversion" field. +func OsversionEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldOsversion, v)) +} + +// OsversionContainsFold applies the ContainsFold predicate on the "osversion" field. +func OsversionContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldOsversion, v)) +} + +// FeatureflagsEQ applies the EQ predicate on the "featureflags" field. +func FeatureflagsEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldEQ(FieldFeatureflags, v)) +} + +// FeatureflagsNEQ applies the NEQ predicate on the "featureflags" field. +func FeatureflagsNEQ(v string) predicate.Machine { + return predicate.Machine(sql.FieldNEQ(FieldFeatureflags, v)) +} + +// FeatureflagsIn applies the In predicate on the "featureflags" field. +func FeatureflagsIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsNotIn applies the NotIn predicate on the "featureflags" field. +func FeatureflagsNotIn(vs ...string) predicate.Machine { + return predicate.Machine(sql.FieldNotIn(FieldFeatureflags, vs...)) +} + +// FeatureflagsGT applies the GT predicate on the "featureflags" field. +func FeatureflagsGT(v string) predicate.Machine { + return predicate.Machine(sql.FieldGT(FieldFeatureflags, v)) +} + +// FeatureflagsGTE applies the GTE predicate on the "featureflags" field. +func FeatureflagsGTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldGTE(FieldFeatureflags, v)) +} + +// FeatureflagsLT applies the LT predicate on the "featureflags" field. +func FeatureflagsLT(v string) predicate.Machine { + return predicate.Machine(sql.FieldLT(FieldFeatureflags, v)) +} + +// FeatureflagsLTE applies the LTE predicate on the "featureflags" field. +func FeatureflagsLTE(v string) predicate.Machine { + return predicate.Machine(sql.FieldLTE(FieldFeatureflags, v)) +} + +// FeatureflagsContains applies the Contains predicate on the "featureflags" field. +func FeatureflagsContains(v string) predicate.Machine { + return predicate.Machine(sql.FieldContains(FieldFeatureflags, v)) +} + +// FeatureflagsHasPrefix applies the HasPrefix predicate on the "featureflags" field. +func FeatureflagsHasPrefix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasPrefix(FieldFeatureflags, v)) +} + +// FeatureflagsHasSuffix applies the HasSuffix predicate on the "featureflags" field. +func FeatureflagsHasSuffix(v string) predicate.Machine { + return predicate.Machine(sql.FieldHasSuffix(FieldFeatureflags, v)) +} + +// FeatureflagsIsNil applies the IsNil predicate on the "featureflags" field. +func FeatureflagsIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldFeatureflags)) +} + +// FeatureflagsNotNil applies the NotNil predicate on the "featureflags" field. +func FeatureflagsNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldFeatureflags)) +} + +// FeatureflagsEqualFold applies the EqualFold predicate on the "featureflags" field. +func FeatureflagsEqualFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldEqualFold(FieldFeatureflags, v)) +} + +// FeatureflagsContainsFold applies the ContainsFold predicate on the "featureflags" field. +func FeatureflagsContainsFold(v string) predicate.Machine { + return predicate.Machine(sql.FieldContainsFold(FieldFeatureflags, v)) +} + +// HubstateIsNil applies the IsNil predicate on the "hubstate" field. +func HubstateIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldHubstate)) +} + +// HubstateNotNil applies the NotNil predicate on the "hubstate" field. +func HubstateNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldHubstate)) +} + +// DatasourcesIsNil applies the IsNil predicate on the "datasources" field. +func DatasourcesIsNil() predicate.Machine { + return predicate.Machine(sql.FieldIsNull(FieldDatasources)) +} + +// DatasourcesNotNil applies the NotNil predicate on the "datasources" field. +func DatasourcesNotNil() predicate.Machine { + return predicate.Machine(sql.FieldNotNull(FieldDatasources)) } // HasAlerts applies the HasEdge predicate on the "alerts" edge. @@ -1231,7 +975,6 @@ func HasAlerts() predicate.Machine { return predicate.Machine(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -1241,11 +984,7 @@ func HasAlerts() predicate.Machine { // HasAlertsWith applies the HasEdge predicate on the "alerts" edge with a given conditions (other predicates). func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { return predicate.Machine(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AlertsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, AlertsTable, AlertsColumn), - ) + step := newAlertsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -1256,32 +995,15 @@ func HasAlertsWith(preds ...predicate.Alert) predicate.Machine { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Machine(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Machine) predicate.Machine { - return predicate.Machine(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Machine(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/machine_create.go b/pkg/database/ent/machine_create.go index efe02782f6b..fba8400798c 100644 --- a/pkg/database/ent/machine_create.go +++ b/pkg/database/ent/machine_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/schema/field" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // MachineCreate is the builder for creating a Machine entity. @@ -137,34 +138,74 @@ func (mc *MachineCreate) SetNillableIsValidated(b *bool) *MachineCreate { return mc } -// SetStatus sets the "status" field. -func (mc *MachineCreate) SetStatus(s string) *MachineCreate { - mc.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (mc *MachineCreate) SetAuthType(s string) *MachineCreate { + mc.mutation.SetAuthType(s) return mc } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (mc *MachineCreate) SetNillableStatus(s *string) *MachineCreate { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate { if s != nil { - mc.SetStatus(*s) + mc.SetAuthType(*s) } return mc } -// SetAuthType sets the "auth_type" field. -func (mc *MachineCreate) SetAuthType(s string) *MachineCreate { - mc.mutation.SetAuthType(s) +// SetOsname sets the "osname" field. +func (mc *MachineCreate) SetOsname(s string) *MachineCreate { + mc.mutation.SetOsname(s) return mc } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. -func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate { +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (mc *MachineCreate) SetNillableOsname(s *string) *MachineCreate { if s != nil { - mc.SetAuthType(*s) + mc.SetOsname(*s) } return mc } +// SetOsversion sets the "osversion" field. +func (mc *MachineCreate) SetOsversion(s string) *MachineCreate { + mc.mutation.SetOsversion(s) + return mc +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (mc *MachineCreate) SetNillableOsversion(s *string) *MachineCreate { + if s != nil { + mc.SetOsversion(*s) + } + return mc +} + +// SetFeatureflags sets the "featureflags" field. +func (mc *MachineCreate) SetFeatureflags(s string) *MachineCreate { + mc.mutation.SetFeatureflags(s) + return mc +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (mc *MachineCreate) SetNillableFeatureflags(s *string) *MachineCreate { + if s != nil { + mc.SetFeatureflags(*s) + } + return mc +} + +// SetHubstate sets the "hubstate" field. +func (mc *MachineCreate) SetHubstate(ms map[string][]schema.ItemState) *MachineCreate { + mc.mutation.SetHubstate(ms) + return mc +} + +// SetDatasources sets the "datasources" field. +func (mc *MachineCreate) SetDatasources(m map[string]int64) *MachineCreate { + mc.mutation.SetDatasources(m) + return mc +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mc *MachineCreate) AddAlertIDs(ids ...int) *MachineCreate { mc.mutation.AddAlertIDs(ids...) @@ -187,50 +228,8 @@ func (mc *MachineCreate) Mutation() *MachineMutation { // Save creates the Machine in the database. func (mc *MachineCreate) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -269,10 +268,6 @@ func (mc *MachineCreate) defaults() { v := machine.DefaultLastPush() mc.mutation.SetLastPush(v) } - if _, ok := mc.mutation.LastHeartbeat(); !ok { - v := machine.DefaultLastHeartbeat() - mc.mutation.SetLastHeartbeat(v) - } if _, ok := mc.mutation.IsValidated(); !ok { v := machine.DefaultIsValidated mc.mutation.SetIsValidated(v) @@ -285,6 +280,12 @@ func (mc *MachineCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MachineCreate) check() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Machine.created_at"`)} + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Machine.updated_at"`)} + } if _, ok := mc.mutation.MachineId(); !ok { return &ValidationError{Name: "machineId", err: errors.New(`ent: missing required field "Machine.machineId"`)} } @@ -309,6 +310,9 @@ func (mc *MachineCreate) check() error { } func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -318,116 +322,80 @@ func (mc *MachineCreate) sqlSave(ctx context.Context) (*Machine, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { var ( _node = &Machine{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(machine.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := mc.mutation.LastPush(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) _node.LastPush = &value } if value, ok := mc.mutation.LastHeartbeat(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) _node.LastHeartbeat = &value } if value, ok := mc.mutation.MachineId(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.SetField(machine.FieldMachineId, field.TypeString, value) _node.MachineId = value } if value, ok := mc.mutation.Password(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) _node.Password = value } if value, ok := mc.mutation.IpAddress(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) _node.IpAddress = value } if value, ok := mc.mutation.Scenarios(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) _node.Scenarios = value } if value, ok := mc.mutation.Version(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) _node.Version = value } if value, ok := mc.mutation.IsValidated(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) _node.IsValidated = value } - if value, ok := mc.mutation.Status(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - _node.Status = value - } if value, ok := mc.mutation.AuthType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) _node.AuthType = value } + if value, ok := mc.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + _node.Osname = value + } + if value, ok := mc.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + _node.Osversion = value + } + if value, ok := mc.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + _node.Featureflags = value + } + if value, ok := mc.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + _node.Hubstate = value + } + if value, ok := mc.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + _node.Datasources = value + } if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -436,10 +404,7 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -453,11 +418,15 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { // MachineCreateBulk is the builder for creating many Machine entities in bulk. type MachineCreateBulk struct { config + err error builders []*MachineCreate } // Save creates the Machine entities in the database. func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Machine, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -474,8 +443,8 @@ func (mcb *MachineCreateBulk) Save(ctx context.Context) ([]*Machine, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/machine_delete.go b/pkg/database/ent/machine_delete.go index bead8acb46d..ac3aa751d5e 100644 --- a/pkg/database/ent/machine_delete.go +++ b/pkg/database/ent/machine_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MachineDelete) Where(ps ...predicate.Machine) *MachineDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MachineDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MachineDelete) ExecX(ctx context.Context) int { } func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(machine.Table, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MachineDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MachineDeleteOne struct { md *MachineDelete } +// Where appends a list predicates to the MachineDelete builder. +func (mdo *MachineDeleteOne) Where(ps ...predicate.Machine) *MachineDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MachineDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MachineDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/machine_query.go b/pkg/database/ent/machine_query.go index 2839142196b..462c2cf35b1 100644 --- a/pkg/database/ent/machine_query.go +++ b/pkg/database/ent/machine_query.go @@ -19,11 +19,9 @@ import ( // MachineQuery is the builder for querying Machine entities. type MachineQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []machine.OrderOption + inters []Interceptor predicates []predicate.Machine withAlerts *AlertQuery // intermediate query (i.e. traversal path). @@ -37,34 +35,34 @@ func (mq *MachineQuery) Where(ps ...predicate.Machine) *MachineQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MachineQuery) Limit(limit int) *MachineQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MachineQuery) Offset(offset int) *MachineQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MachineQuery) Unique(unique bool) *MachineQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MachineQuery) Order(o ...OrderFunc) *MachineQuery { +// Order specifies how the records should be ordered. +func (mq *MachineQuery) Order(o ...machine.OrderOption) *MachineQuery { mq.order = append(mq.order, o...) return mq } // QueryAlerts chains the current query on the "alerts" edge. func (mq *MachineQuery) QueryAlerts() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -87,7 +85,7 @@ func (mq *MachineQuery) QueryAlerts() *AlertQuery { // First returns the first Machine entity from the query. // Returns a *NotFoundError when no Machine was found. func (mq *MachineQuery) First(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -110,7 +108,7 @@ func (mq *MachineQuery) FirstX(ctx context.Context) *Machine { // Returns a *NotFoundError when no Machine ID was found. func (mq *MachineQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -133,7 +131,7 @@ func (mq *MachineQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Machine entity is found. // Returns a *NotFoundError when no Machine entities are found. func (mq *MachineQuery) Only(ctx context.Context) (*Machine, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -161,7 +159,7 @@ func (mq *MachineQuery) OnlyX(ctx context.Context) *Machine { // Returns a *NotFoundError when no entities are found. func (mq *MachineQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -186,10 +184,12 @@ func (mq *MachineQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Machines. func (mq *MachineQuery) All(ctx context.Context) ([]*Machine, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Machine, *MachineQuery]() + return withInterceptors[[]*Machine](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -202,9 +202,12 @@ func (mq *MachineQuery) AllX(ctx context.Context) []*Machine { } // IDs executes the query and returns a list of Machine IDs. -func (mq *MachineQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MachineQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(machine.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -221,10 +224,11 @@ func (mq *MachineQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MachineQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MachineQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -238,10 +242,15 @@ func (mq *MachineQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MachineQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -261,22 +270,21 @@ func (mq *MachineQuery) Clone() *MachineQuery { } return &MachineQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]machine.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Machine{}, mq.predicates...), withAlerts: mq.withAlerts.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithAlerts tells the query-builder to eager-load the nodes that are connected to // the "alerts" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -299,16 +307,11 @@ func (mq *MachineQuery) WithAlerts(opts ...func(*AlertQuery)) *MachineQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy { - grbuild := &MachineGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MachineGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = machine.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -325,15 +328,30 @@ func (mq *MachineQuery) GroupBy(field string, fields ...string) *MachineGroupBy // Select(machine.FieldCreatedAt). // Scan(ctx, &v) func (mq *MachineQuery) Select(fields ...string) *MachineSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MachineSelect{MachineQuery: mq} - selbuild.label = machine.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MachineSelect{MachineQuery: mq} + sbuild.label = machine.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MachineSelect configured with the given aggregations. +func (mq *MachineQuery) Aggregate(fns ...AggregateFunc) *MachineSelect { + return mq.Select().Aggregate(fns...) } func (mq *MachineQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !machine.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -396,7 +414,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } query.withFKs = true query.Where(predicate.Alert(func(s *sql.Selector) { - s.Where(sql.InValues(machine.AlertsColumn, fks...)) + s.Where(sql.InValues(s.C(machine.AlertsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -409,7 +427,7 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "machine_alerts" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -418,41 +436,22 @@ func (mq *MachineQuery) loadAlerts(ctx context.Context, query *AlertQuery, nodes func (mq *MachineQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MachineQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, machine.FieldID) for i := range fields { @@ -468,10 +467,10 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -487,7 +486,7 @@ func (mq *MachineQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(machine.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = machine.Columns } @@ -496,7 +495,7 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -505,12 +504,12 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -518,13 +517,8 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector { // MachineGroupBy is the group-by builder for Machine entities. type MachineGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MachineQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -533,74 +527,77 @@ func (mgb *MachineGroupBy) Aggregate(fns ...AggregateFunc) *MachineGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (mgb *MachineGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MachineGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !machine.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MachineGroupBy) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MachineGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MachineSelect is the builder for selecting fields of Machine entities. type MachineSelect struct { *MachineQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MachineSelect) Aggregate(fns ...AggregateFunc) *MachineSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MachineSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MachineQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MachineQuery, *MachineSelect](ctx, ms.MachineQuery, ms, ms.inters, v) } -func (ms *MachineSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MachineSelect) sqlScan(ctx context.Context, root *MachineQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/machine_update.go b/pkg/database/ent/machine_update.go index de9f8d12460..531baabf0d6 100644 --- a/pkg/database/ent/machine_update.go +++ b/pkg/database/ent/machine_update.go @@ -14,6 +14,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) // MachineUpdate is the builder for updating Machine entities. @@ -29,36 +30,26 @@ func (mu *MachineUpdate) Where(ps ...predicate.Machine) *MachineUpdate { return mu } -// SetCreatedAt sets the "created_at" field. -func (mu *MachineUpdate) SetCreatedAt(t time.Time) *MachineUpdate { - mu.mutation.SetCreatedAt(t) - return mu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (mu *MachineUpdate) ClearCreatedAt() *MachineUpdate { - mu.mutation.ClearCreatedAt() - return mu -} - // SetUpdatedAt sets the "updated_at" field. func (mu *MachineUpdate) SetUpdatedAt(t time.Time) *MachineUpdate { mu.mutation.SetUpdatedAt(t) return mu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (mu *MachineUpdate) ClearUpdatedAt() *MachineUpdate { - mu.mutation.ClearUpdatedAt() - return mu -} - // SetLastPush sets the "last_push" field. func (mu *MachineUpdate) SetLastPush(t time.Time) *MachineUpdate { mu.mutation.SetLastPush(t) return mu } +// SetNillableLastPush sets the "last_push" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableLastPush(t *time.Time) *MachineUpdate { + if t != nil { + mu.SetLastPush(*t) + } + return mu +} + // ClearLastPush clears the value of the "last_push" field. func (mu *MachineUpdate) ClearLastPush() *MachineUpdate { mu.mutation.ClearLastPush() @@ -71,15 +62,17 @@ func (mu *MachineUpdate) SetLastHeartbeat(t time.Time) *MachineUpdate { return mu } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (mu *MachineUpdate) ClearLastHeartbeat() *MachineUpdate { - mu.mutation.ClearLastHeartbeat() +// SetNillableLastHeartbeat sets the "last_heartbeat" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableLastHeartbeat(t *time.Time) *MachineUpdate { + if t != nil { + mu.SetLastHeartbeat(*t) + } return mu } -// SetMachineId sets the "machineId" field. -func (mu *MachineUpdate) SetMachineId(s string) *MachineUpdate { - mu.mutation.SetMachineId(s) +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. +func (mu *MachineUpdate) ClearLastHeartbeat() *MachineUpdate { + mu.mutation.ClearLastHeartbeat() return mu } @@ -89,12 +82,28 @@ func (mu *MachineUpdate) SetPassword(s string) *MachineUpdate { return mu } +// SetNillablePassword sets the "password" field if the given value is not nil. +func (mu *MachineUpdate) SetNillablePassword(s *string) *MachineUpdate { + if s != nil { + mu.SetPassword(*s) + } + return mu +} + // SetIpAddress sets the "ipAddress" field. func (mu *MachineUpdate) SetIpAddress(s string) *MachineUpdate { mu.mutation.SetIpAddress(s) return mu } +// SetNillableIpAddress sets the "ipAddress" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableIpAddress(s *string) *MachineUpdate { + if s != nil { + mu.SetIpAddress(*s) + } + return mu +} + // SetScenarios sets the "scenarios" field. func (mu *MachineUpdate) SetScenarios(s string) *MachineUpdate { mu.mutation.SetScenarios(s) @@ -149,40 +158,104 @@ func (mu *MachineUpdate) SetNillableIsValidated(b *bool) *MachineUpdate { return mu } -// SetStatus sets the "status" field. -func (mu *MachineUpdate) SetStatus(s string) *MachineUpdate { - mu.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate { + mu.mutation.SetAuthType(s) return mu } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (mu *MachineUpdate) SetNillableStatus(s *string) *MachineUpdate { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate { if s != nil { - mu.SetStatus(*s) + mu.SetAuthType(*s) } return mu } -// ClearStatus clears the value of the "status" field. -func (mu *MachineUpdate) ClearStatus() *MachineUpdate { - mu.mutation.ClearStatus() +// SetOsname sets the "osname" field. +func (mu *MachineUpdate) SetOsname(s string) *MachineUpdate { + mu.mutation.SetOsname(s) return mu } -// SetAuthType sets the "auth_type" field. -func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate { - mu.mutation.SetAuthType(s) +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableOsname(s *string) *MachineUpdate { + if s != nil { + mu.SetOsname(*s) + } return mu } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. -func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate { +// ClearOsname clears the value of the "osname" field. +func (mu *MachineUpdate) ClearOsname() *MachineUpdate { + mu.mutation.ClearOsname() + return mu +} + +// SetOsversion sets the "osversion" field. +func (mu *MachineUpdate) SetOsversion(s string) *MachineUpdate { + mu.mutation.SetOsversion(s) + return mu +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableOsversion(s *string) *MachineUpdate { if s != nil { - mu.SetAuthType(*s) + mu.SetOsversion(*s) + } + return mu +} + +// ClearOsversion clears the value of the "osversion" field. +func (mu *MachineUpdate) ClearOsversion() *MachineUpdate { + mu.mutation.ClearOsversion() + return mu +} + +// SetFeatureflags sets the "featureflags" field. +func (mu *MachineUpdate) SetFeatureflags(s string) *MachineUpdate { + mu.mutation.SetFeatureflags(s) + return mu +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableFeatureflags(s *string) *MachineUpdate { + if s != nil { + mu.SetFeatureflags(*s) } return mu } +// ClearFeatureflags clears the value of the "featureflags" field. +func (mu *MachineUpdate) ClearFeatureflags() *MachineUpdate { + mu.mutation.ClearFeatureflags() + return mu +} + +// SetHubstate sets the "hubstate" field. +func (mu *MachineUpdate) SetHubstate(ms map[string][]schema.ItemState) *MachineUpdate { + mu.mutation.SetHubstate(ms) + return mu +} + +// ClearHubstate clears the value of the "hubstate" field. +func (mu *MachineUpdate) ClearHubstate() *MachineUpdate { + mu.mutation.ClearHubstate() + return mu +} + +// SetDatasources sets the "datasources" field. +func (mu *MachineUpdate) SetDatasources(m map[string]int64) *MachineUpdate { + mu.mutation.SetDatasources(m) + return mu +} + +// ClearDatasources clears the value of the "datasources" field. +func (mu *MachineUpdate) ClearDatasources() *MachineUpdate { + mu.mutation.ClearDatasources() + return mu +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mu *MachineUpdate) AddAlertIDs(ids ...int) *MachineUpdate { mu.mutation.AddAlertIDs(ids...) @@ -226,41 +299,8 @@ func (mu *MachineUpdate) RemoveAlerts(a ...*Alert) *MachineUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MachineUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -287,22 +327,10 @@ func (mu *MachineUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (mu *MachineUpdate) defaults() { - if _, ok := mu.mutation.CreatedAt(); !ok && !mu.mutation.CreatedAtCleared() { - v := machine.UpdateDefaultCreatedAt() - mu.mutation.SetCreatedAt(v) - } - if _, ok := mu.mutation.UpdatedAt(); !ok && !mu.mutation.UpdatedAtCleared() { + if _, ok := mu.mutation.UpdatedAt(); !ok { v := machine.UpdateDefaultUpdatedAt() mu.mutation.SetUpdatedAt(v) } - if _, ok := mu.mutation.LastPush(); !ok && !mu.mutation.LastPushCleared() { - v := machine.UpdateDefaultLastPush() - mu.mutation.SetLastPush(v) - } - if _, ok := mu.mutation.LastHeartbeat(); !ok && !mu.mutation.LastHeartbeatCleared() { - v := machine.UpdateDefaultLastHeartbeat() - mu.mutation.SetLastHeartbeat(v) - } } // check runs all checks and user-defined validators on the builder. @@ -316,16 +344,10 @@ func (mu *MachineUpdate) check() error { } func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := mu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -333,131 +355,74 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - } - if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) - } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - } - if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if value, ok := mu.mutation.LastPush(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if mu.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := mu.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if mu.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) - } - if value, ok := mu.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := mu.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := mu.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := mu.mutation.Scenarios(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) } if mu.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := mu.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if mu.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := mu.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) - } - if value, ok := mu.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - } - if mu.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := mu.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) + } + if value, ok := mu.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + } + if mu.mutation.OsnameCleared() { + _spec.ClearField(machine.FieldOsname, field.TypeString) + } + if value, ok := mu.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + } + if mu.mutation.OsversionCleared() { + _spec.ClearField(machine.FieldOsversion, field.TypeString) + } + if value, ok := mu.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + } + if mu.mutation.FeatureflagsCleared() { + _spec.ClearField(machine.FieldFeatureflags, field.TypeString) + } + if value, ok := mu.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + } + if mu.mutation.HubstateCleared() { + _spec.ClearField(machine.FieldHubstate, field.TypeJSON) + } + if value, ok := mu.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + } + if mu.mutation.DatasourcesCleared() { + _spec.ClearField(machine.FieldDatasources, field.TypeJSON) } if mu.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -467,10 +432,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -483,10 +445,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -502,10 +461,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -521,6 +477,7 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -532,36 +489,26 @@ type MachineUpdateOne struct { mutation *MachineMutation } -// SetCreatedAt sets the "created_at" field. -func (muo *MachineUpdateOne) SetCreatedAt(t time.Time) *MachineUpdateOne { - muo.mutation.SetCreatedAt(t) - return muo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (muo *MachineUpdateOne) ClearCreatedAt() *MachineUpdateOne { - muo.mutation.ClearCreatedAt() - return muo -} - // SetUpdatedAt sets the "updated_at" field. func (muo *MachineUpdateOne) SetUpdatedAt(t time.Time) *MachineUpdateOne { muo.mutation.SetUpdatedAt(t) return muo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (muo *MachineUpdateOne) ClearUpdatedAt() *MachineUpdateOne { - muo.mutation.ClearUpdatedAt() - return muo -} - // SetLastPush sets the "last_push" field. func (muo *MachineUpdateOne) SetLastPush(t time.Time) *MachineUpdateOne { muo.mutation.SetLastPush(t) return muo } +// SetNillableLastPush sets the "last_push" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableLastPush(t *time.Time) *MachineUpdateOne { + if t != nil { + muo.SetLastPush(*t) + } + return muo +} + // ClearLastPush clears the value of the "last_push" field. func (muo *MachineUpdateOne) ClearLastPush() *MachineUpdateOne { muo.mutation.ClearLastPush() @@ -574,15 +521,17 @@ func (muo *MachineUpdateOne) SetLastHeartbeat(t time.Time) *MachineUpdateOne { return muo } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (muo *MachineUpdateOne) ClearLastHeartbeat() *MachineUpdateOne { - muo.mutation.ClearLastHeartbeat() +// SetNillableLastHeartbeat sets the "last_heartbeat" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableLastHeartbeat(t *time.Time) *MachineUpdateOne { + if t != nil { + muo.SetLastHeartbeat(*t) + } return muo } -// SetMachineId sets the "machineId" field. -func (muo *MachineUpdateOne) SetMachineId(s string) *MachineUpdateOne { - muo.mutation.SetMachineId(s) +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. +func (muo *MachineUpdateOne) ClearLastHeartbeat() *MachineUpdateOne { + muo.mutation.ClearLastHeartbeat() return muo } @@ -592,12 +541,28 @@ func (muo *MachineUpdateOne) SetPassword(s string) *MachineUpdateOne { return muo } +// SetNillablePassword sets the "password" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillablePassword(s *string) *MachineUpdateOne { + if s != nil { + muo.SetPassword(*s) + } + return muo +} + // SetIpAddress sets the "ipAddress" field. func (muo *MachineUpdateOne) SetIpAddress(s string) *MachineUpdateOne { muo.mutation.SetIpAddress(s) return muo } +// SetNillableIpAddress sets the "ipAddress" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableIpAddress(s *string) *MachineUpdateOne { + if s != nil { + muo.SetIpAddress(*s) + } + return muo +} + // SetScenarios sets the "scenarios" field. func (muo *MachineUpdateOne) SetScenarios(s string) *MachineUpdateOne { muo.mutation.SetScenarios(s) @@ -652,40 +617,104 @@ func (muo *MachineUpdateOne) SetNillableIsValidated(b *bool) *MachineUpdateOne { return muo } -// SetStatus sets the "status" field. -func (muo *MachineUpdateOne) SetStatus(s string) *MachineUpdateOne { - muo.mutation.SetStatus(s) +// SetAuthType sets the "auth_type" field. +func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne { + muo.mutation.SetAuthType(s) return muo } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (muo *MachineUpdateOne) SetNillableStatus(s *string) *MachineUpdateOne { +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne { if s != nil { - muo.SetStatus(*s) + muo.SetAuthType(*s) } return muo } -// ClearStatus clears the value of the "status" field. -func (muo *MachineUpdateOne) ClearStatus() *MachineUpdateOne { - muo.mutation.ClearStatus() +// SetOsname sets the "osname" field. +func (muo *MachineUpdateOne) SetOsname(s string) *MachineUpdateOne { + muo.mutation.SetOsname(s) return muo } -// SetAuthType sets the "auth_type" field. -func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne { - muo.mutation.SetAuthType(s) +// SetNillableOsname sets the "osname" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableOsname(s *string) *MachineUpdateOne { + if s != nil { + muo.SetOsname(*s) + } return muo } -// SetNillableAuthType sets the "auth_type" field if the given value is not nil. -func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne { +// ClearOsname clears the value of the "osname" field. +func (muo *MachineUpdateOne) ClearOsname() *MachineUpdateOne { + muo.mutation.ClearOsname() + return muo +} + +// SetOsversion sets the "osversion" field. +func (muo *MachineUpdateOne) SetOsversion(s string) *MachineUpdateOne { + muo.mutation.SetOsversion(s) + return muo +} + +// SetNillableOsversion sets the "osversion" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableOsversion(s *string) *MachineUpdateOne { if s != nil { - muo.SetAuthType(*s) + muo.SetOsversion(*s) } return muo } +// ClearOsversion clears the value of the "osversion" field. +func (muo *MachineUpdateOne) ClearOsversion() *MachineUpdateOne { + muo.mutation.ClearOsversion() + return muo +} + +// SetFeatureflags sets the "featureflags" field. +func (muo *MachineUpdateOne) SetFeatureflags(s string) *MachineUpdateOne { + muo.mutation.SetFeatureflags(s) + return muo +} + +// SetNillableFeatureflags sets the "featureflags" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableFeatureflags(s *string) *MachineUpdateOne { + if s != nil { + muo.SetFeatureflags(*s) + } + return muo +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (muo *MachineUpdateOne) ClearFeatureflags() *MachineUpdateOne { + muo.mutation.ClearFeatureflags() + return muo +} + +// SetHubstate sets the "hubstate" field. +func (muo *MachineUpdateOne) SetHubstate(ms map[string][]schema.ItemState) *MachineUpdateOne { + muo.mutation.SetHubstate(ms) + return muo +} + +// ClearHubstate clears the value of the "hubstate" field. +func (muo *MachineUpdateOne) ClearHubstate() *MachineUpdateOne { + muo.mutation.ClearHubstate() + return muo +} + +// SetDatasources sets the "datasources" field. +func (muo *MachineUpdateOne) SetDatasources(m map[string]int64) *MachineUpdateOne { + muo.mutation.SetDatasources(m) + return muo +} + +// ClearDatasources clears the value of the "datasources" field. +func (muo *MachineUpdateOne) ClearDatasources() *MachineUpdateOne { + muo.mutation.ClearDatasources() + return muo +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (muo *MachineUpdateOne) AddAlertIDs(ids ...int) *MachineUpdateOne { muo.mutation.AddAlertIDs(ids...) @@ -727,6 +756,12 @@ func (muo *MachineUpdateOne) RemoveAlerts(a ...*Alert) *MachineUpdateOne { return muo.RemoveAlertIDs(ids...) } +// Where appends a list predicates to the MachineUpdate builder. +func (muo *MachineUpdateOne) Where(ps ...predicate.Machine) *MachineUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpdateOne { @@ -736,47 +771,8 @@ func (muo *MachineUpdateOne) Select(field string, fields ...string) *MachineUpda // Save executes the query and returns the updated Machine entity. func (muo *MachineUpdateOne) Save(ctx context.Context) (*Machine, error) { - var ( - err error - node *Machine - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MachineMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Machine) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MachineMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -803,22 +799,10 @@ func (muo *MachineUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (muo *MachineUpdateOne) defaults() { - if _, ok := muo.mutation.CreatedAt(); !ok && !muo.mutation.CreatedAtCleared() { - v := machine.UpdateDefaultCreatedAt() - muo.mutation.SetCreatedAt(v) - } - if _, ok := muo.mutation.UpdatedAt(); !ok && !muo.mutation.UpdatedAtCleared() { + if _, ok := muo.mutation.UpdatedAt(); !ok { v := machine.UpdateDefaultUpdatedAt() muo.mutation.SetUpdatedAt(v) } - if _, ok := muo.mutation.LastPush(); !ok && !muo.mutation.LastPushCleared() { - v := machine.UpdateDefaultLastPush() - muo.mutation.SetLastPush(v) - } - if _, ok := muo.mutation.LastHeartbeat(); !ok && !muo.mutation.LastHeartbeatCleared() { - v := machine.UpdateDefaultLastHeartbeat() - muo.mutation.SetLastHeartbeat(v) - } } // check runs all checks and user-defined validators on the builder. @@ -832,16 +816,10 @@ func (muo *MachineUpdateOne) check() error { } func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: machine.Table, - Columns: machine.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: machine.FieldID, - }, - }, + if err := muo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(machine.Table, machine.Columns, sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Machine.id" for update`)} @@ -866,131 +844,74 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } } } - if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldCreatedAt, - }) - } - if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldCreatedAt, - }) - } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldUpdatedAt, - }) - } - if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldUpdatedAt, - }) + _spec.SetField(machine.FieldUpdatedAt, field.TypeTime, value) } if value, ok := muo.mutation.LastPush(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastPush, - }) + _spec.SetField(machine.FieldLastPush, field.TypeTime, value) } if muo.mutation.LastPushCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastPush, - }) + _spec.ClearField(machine.FieldLastPush, field.TypeTime) } if value, ok := muo.mutation.LastHeartbeat(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: machine.FieldLastHeartbeat, - }) + _spec.SetField(machine.FieldLastHeartbeat, field.TypeTime, value) } if muo.mutation.LastHeartbeatCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: machine.FieldLastHeartbeat, - }) - } - if value, ok := muo.mutation.MachineId(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldMachineId, - }) + _spec.ClearField(machine.FieldLastHeartbeat, field.TypeTime) } if value, ok := muo.mutation.Password(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldPassword, - }) + _spec.SetField(machine.FieldPassword, field.TypeString, value) } if value, ok := muo.mutation.IpAddress(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldIpAddress, - }) + _spec.SetField(machine.FieldIpAddress, field.TypeString, value) } if value, ok := muo.mutation.Scenarios(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldScenarios, - }) + _spec.SetField(machine.FieldScenarios, field.TypeString, value) } if muo.mutation.ScenariosCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldScenarios, - }) + _spec.ClearField(machine.FieldScenarios, field.TypeString) } if value, ok := muo.mutation.Version(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldVersion, - }) + _spec.SetField(machine.FieldVersion, field.TypeString, value) } if muo.mutation.VersionCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldVersion, - }) + _spec.ClearField(machine.FieldVersion, field.TypeString) } if value, ok := muo.mutation.IsValidated(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: machine.FieldIsValidated, - }) - } - if value, ok := muo.mutation.Status(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldStatus, - }) - } - if muo.mutation.StatusCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: machine.FieldStatus, - }) + _spec.SetField(machine.FieldIsValidated, field.TypeBool, value) } if value, ok := muo.mutation.AuthType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: machine.FieldAuthType, - }) + _spec.SetField(machine.FieldAuthType, field.TypeString, value) + } + if value, ok := muo.mutation.Osname(); ok { + _spec.SetField(machine.FieldOsname, field.TypeString, value) + } + if muo.mutation.OsnameCleared() { + _spec.ClearField(machine.FieldOsname, field.TypeString) + } + if value, ok := muo.mutation.Osversion(); ok { + _spec.SetField(machine.FieldOsversion, field.TypeString, value) + } + if muo.mutation.OsversionCleared() { + _spec.ClearField(machine.FieldOsversion, field.TypeString) + } + if value, ok := muo.mutation.Featureflags(); ok { + _spec.SetField(machine.FieldFeatureflags, field.TypeString, value) + } + if muo.mutation.FeatureflagsCleared() { + _spec.ClearField(machine.FieldFeatureflags, field.TypeString) + } + if value, ok := muo.mutation.Hubstate(); ok { + _spec.SetField(machine.FieldHubstate, field.TypeJSON, value) + } + if muo.mutation.HubstateCleared() { + _spec.ClearField(machine.FieldHubstate, field.TypeJSON) + } + if value, ok := muo.mutation.Datasources(); ok { + _spec.SetField(machine.FieldDatasources, field.TypeJSON, value) + } + if muo.mutation.DatasourcesCleared() { + _spec.ClearField(machine.FieldDatasources, field.TypeJSON) } if muo.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1000,10 +921,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -1016,10 +934,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1035,10 +950,7 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Columns: []string{machine.AlertsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -1057,5 +969,6 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/meta.go b/pkg/database/ent/meta.go index 660f1a4db73..7e29627957c 100644 --- a/pkg/database/ent/meta.go +++ b/pkg/database/ent/meta.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" @@ -18,9 +19,9 @@ type Meta struct { // ID of the ent. ID int `json:"id,omitempty"` // CreatedAt holds the value of the "created_at" field. - CreatedAt *time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt *time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` // Key holds the value of the "key" field. Key string `json:"key,omitempty"` // Value holds the value of the "value" field. @@ -29,7 +30,8 @@ type Meta struct { AlertMetas int `json:"alert_metas,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MetaQuery when eager-loading is set. - Edges MetaEdges `json:"edges"` + Edges MetaEdges `json:"edges"` + selectValues sql.SelectValues } // MetaEdges holds the relations/edges for other nodes in the graph. @@ -44,12 +46,10 @@ type MetaEdges struct { // OwnerOrErr returns the Owner value or an error if the edge // was not loaded in eager-loading, or loaded but was not found. func (e MetaEdges) OwnerOrErr() (*Alert, error) { - if e.loadedTypes[0] { - if e.Owner == nil { - // Edge was loaded but was not found. - return nil, &NotFoundError{label: alert.Label} - } + if e.Owner != nil { return e.Owner, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: alert.Label} } return nil, &NotLoadedError{edge: "owner"} } @@ -66,7 +66,7 @@ func (*Meta) scanValues(columns []string) ([]any, error) { case meta.FieldCreatedAt, meta.FieldUpdatedAt: values[i] = new(sql.NullTime) default: - return nil, fmt.Errorf("unexpected column %q for type Meta", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -90,15 +90,13 @@ func (m *Meta) assignValues(columns []string, values []any) error { if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) } else if value.Valid { - m.CreatedAt = new(time.Time) - *m.CreatedAt = value.Time + m.CreatedAt = value.Time } case meta.FieldUpdatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field updated_at", values[i]) } else if value.Valid { - m.UpdatedAt = new(time.Time) - *m.UpdatedAt = value.Time + m.UpdatedAt = value.Time } case meta.FieldKey: if value, ok := values[i].(*sql.NullString); !ok { @@ -118,21 +116,29 @@ func (m *Meta) assignValues(columns []string, values []any) error { } else if value.Valid { m.AlertMetas = int(value.Int64) } + default: + m.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the Meta. +// This includes values selected through modifiers, order, etc. +func (m *Meta) GetValue(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + // QueryOwner queries the "owner" edge of the Meta entity. func (m *Meta) QueryOwner() *AlertQuery { - return (&MetaClient{config: m.config}).QueryOwner(m) + return NewMetaClient(m.config).QueryOwner(m) } // Update returns a builder for updating this Meta. // Note that you need to call Meta.Unwrap() before calling this method if this Meta // was returned from a transaction, and the transaction was committed or rolled back. func (m *Meta) Update() *MetaUpdateOne { - return (&MetaClient{config: m.config}).UpdateOne(m) + return NewMetaClient(m.config).UpdateOne(m) } // Unwrap unwraps the Meta entity that was returned from a transaction after it was closed, @@ -151,15 +157,11 @@ func (m *Meta) String() string { var builder strings.Builder builder.WriteString("Meta(") builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) - if v := m.CreatedAt; v != nil { - builder.WriteString("created_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("created_at=") + builder.WriteString(m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") - if v := m.UpdatedAt; v != nil { - builder.WriteString("updated_at=") - builder.WriteString(v.Format(time.ANSIC)) - } + builder.WriteString("updated_at=") + builder.WriteString(m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("key=") builder.WriteString(m.Key) @@ -175,9 +177,3 @@ func (m *Meta) String() string { // MetaSlice is a parsable slice of Meta. type MetaSlice []*Meta - -func (m MetaSlice) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/pkg/database/ent/meta/meta.go b/pkg/database/ent/meta/meta.go index 6d10f258919..ff41361616a 100644 --- a/pkg/database/ent/meta/meta.go +++ b/pkg/database/ent/meta/meta.go @@ -4,6 +4,9 @@ package meta import ( "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -57,8 +60,6 @@ func ValidColumn(column string) bool { var ( // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time - // UpdateDefaultCreatedAt holds the default value on update for the "created_at" field. - UpdateDefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. @@ -66,3 +67,50 @@ var ( // ValueValidator is a validator for the "value" field. It is called by the builders before save. ValueValidator func(string) error ) + +// OrderOption defines the ordering options for the Meta queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByAlertMetas orders the results by the alert_metas field. +func ByAlertMetas(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlertMetas, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/pkg/database/ent/meta/where.go b/pkg/database/ent/meta/where.go index 479792fd4a6..6d5d54c0482 100644 --- a/pkg/database/ent/meta/where.go +++ b/pkg/database/ent/meta/where.go @@ -12,512 +12,312 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Meta(sql.FieldLTE(FieldID, id)) } // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. func UpdatedAt(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // Key applies equality check predicate on the "key" field. It's identical to KeyEQ. func Key(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // AlertMetas applies equality check predicate on the "alert_metas" field. It's identical to AlertMetasEQ. func AlertMetas(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldCreatedAt, v)) } // CreatedAtNEQ applies the NEQ predicate on the "created_at" field. func CreatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldCreatedAt, v)) } // CreatedAtIn applies the In predicate on the "created_at" field. func CreatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldCreatedAt, vs...)) } // CreatedAtNotIn applies the NotIn predicate on the "created_at" field. func CreatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldCreatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldCreatedAt, vs...)) } // CreatedAtGT applies the GT predicate on the "created_at" field. func CreatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldCreatedAt, v)) } // CreatedAtGTE applies the GTE predicate on the "created_at" field. func CreatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldCreatedAt, v)) } // CreatedAtLT applies the LT predicate on the "created_at" field. func CreatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldCreatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldCreatedAt, v)) } // CreatedAtLTE applies the LTE predicate on the "created_at" field. func CreatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldCreatedAt), v)) - }) -} - -// CreatedAtIsNil applies the IsNil predicate on the "created_at" field. -func CreatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldCreatedAt))) - }) -} - -// CreatedAtNotNil applies the NotNil predicate on the "created_at" field. -func CreatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldCreatedAt))) - }) + return predicate.Meta(sql.FieldLTE(FieldCreatedAt, v)) } // UpdatedAtEQ applies the EQ predicate on the "updated_at" field. func UpdatedAtEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldUpdatedAt, v)) } // UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. func UpdatedAtNEQ(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldUpdatedAt, v)) } // UpdatedAtIn applies the In predicate on the "updated_at" field. func UpdatedAtIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldUpdatedAt, vs...)) } // UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. func UpdatedAtNotIn(vs ...time.Time) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldUpdatedAt, vs...)) } // UpdatedAtGT applies the GT predicate on the "updated_at" field. func UpdatedAtGT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGT(FieldUpdatedAt, v)) } // UpdatedAtGTE applies the GTE predicate on the "updated_at" field. func UpdatedAtGTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldUpdatedAt, v)) } // UpdatedAtLT applies the LT predicate on the "updated_at" field. func UpdatedAtLT(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUpdatedAt), v)) - }) + return predicate.Meta(sql.FieldLT(FieldUpdatedAt, v)) } // UpdatedAtLTE applies the LTE predicate on the "updated_at" field. func UpdatedAtLTE(v time.Time) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUpdatedAt), v)) - }) -} - -// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field. -func UpdatedAtIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldUpdatedAt))) - }) -} - -// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field. -func UpdatedAtNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldUpdatedAt))) - }) + return predicate.Meta(sql.FieldLTE(FieldUpdatedAt, v)) } // KeyEQ applies the EQ predicate on the "key" field. func KeyEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldKey, v)) } // KeyNEQ applies the NEQ predicate on the "key" field. func KeyNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldKey, v)) } // KeyIn applies the In predicate on the "key" field. func KeyIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldKey, vs...)) } // KeyNotIn applies the NotIn predicate on the "key" field. func KeyNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldKey), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldKey, vs...)) } // KeyGT applies the GT predicate on the "key" field. func KeyGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGT(FieldKey, v)) } // KeyGTE applies the GTE predicate on the "key" field. func KeyGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldKey, v)) } // KeyLT applies the LT predicate on the "key" field. func KeyLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLT(FieldKey, v)) } // KeyLTE applies the LTE predicate on the "key" field. func KeyLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldKey, v)) } // KeyContains applies the Contains predicate on the "key" field. func KeyContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContains(FieldKey, v)) } // KeyHasPrefix applies the HasPrefix predicate on the "key" field. func KeyHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldKey, v)) } // KeyHasSuffix applies the HasSuffix predicate on the "key" field. func KeyHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldKey, v)) } // KeyEqualFold applies the EqualFold predicate on the "key" field. func KeyEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldKey, v)) } // KeyContainsFold applies the ContainsFold predicate on the "key" field. func KeyContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldKey), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldKey, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.Meta(sql.FieldContainsFold(FieldValue, v)) } // AlertMetasEQ applies the EQ predicate on the "alert_metas" field. func AlertMetasEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldEQ(FieldAlertMetas, v)) } // AlertMetasNEQ applies the NEQ predicate on the "alert_metas" field. func AlertMetasNEQ(v int) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlertMetas), v)) - }) + return predicate.Meta(sql.FieldNEQ(FieldAlertMetas, v)) } // AlertMetasIn applies the In predicate on the "alert_metas" field. func AlertMetasIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldIn(FieldAlertMetas, vs...)) } // AlertMetasNotIn applies the NotIn predicate on the "alert_metas" field. func AlertMetasNotIn(vs ...int) predicate.Meta { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlertMetas), v...)) - }) + return predicate.Meta(sql.FieldNotIn(FieldAlertMetas, vs...)) } // AlertMetasIsNil applies the IsNil predicate on the "alert_metas" field. func AlertMetasIsNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldIsNull(FieldAlertMetas)) } // AlertMetasNotNil applies the NotNil predicate on the "alert_metas" field. func AlertMetasNotNil() predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldAlertMetas))) - }) + return predicate.Meta(sql.FieldNotNull(FieldAlertMetas)) } // HasOwner applies the HasEdge predicate on the "owner" edge. @@ -525,7 +325,6 @@ func HasOwner() predicate.Meta { return predicate.Meta(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), ) sqlgraph.HasNeighbors(s, step) @@ -535,11 +334,7 @@ func HasOwner() predicate.Meta { // HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { return predicate.Meta(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(OwnerInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), - ) + step := newOwnerStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -550,32 +345,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Meta { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Meta(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Meta) predicate.Meta { - return predicate.Meta(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Meta(sql.NotPredicates(p)) } diff --git a/pkg/database/ent/meta_create.go b/pkg/database/ent/meta_create.go index df4f6315911..321c4bd7ab4 100644 --- a/pkg/database/ent/meta_create.go +++ b/pkg/database/ent/meta_create.go @@ -101,50 +101,8 @@ func (mc *MetaCreate) Mutation() *MetaMutation { // Save creates the Meta in the database. func (mc *MetaCreate) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -183,6 +141,12 @@ func (mc *MetaCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MetaCreate) check() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Meta.created_at"`)} + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Meta.updated_at"`)} + } if _, ok := mc.mutation.Key(); !ok { return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "Meta.key"`)} } @@ -198,6 +162,9 @@ func (mc *MetaCreate) check() error { } func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -207,50 +174,30 @@ func (mc *MetaCreate) sqlSave(ctx context.Context) (*Meta, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { var ( _node = &Meta{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) ) if value, ok := mc.mutation.CreatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - _node.CreatedAt = &value + _spec.SetField(meta.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value } if value, ok := mc.mutation.UpdatedAt(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - _node.UpdatedAt = &value + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value } if value, ok := mc.mutation.Key(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) + _spec.SetField(meta.FieldKey, field.TypeString, value) _node.Key = value } if value, ok := mc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldValue, field.TypeString, value) _node.Value = value } if nodes := mc.mutation.OwnerIDs(); len(nodes) > 0 { @@ -261,10 +208,7 @@ func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -279,11 +223,15 @@ func (mc *MetaCreate) createSpec() (*Meta, *sqlgraph.CreateSpec) { // MetaCreateBulk is the builder for creating many Meta entities in bulk. type MetaCreateBulk struct { config + err error builders []*MetaCreate } // Save creates the Meta entities in the database. func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { + if mcb.err != nil { + return nil, mcb.err + } specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) nodes := make([]*Meta, len(mcb.builders)) mutators := make([]Mutator, len(mcb.builders)) @@ -300,8 +248,8 @@ func (mcb *MetaCreateBulk) Save(ctx context.Context) ([]*Meta, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) } else { diff --git a/pkg/database/ent/meta_delete.go b/pkg/database/ent/meta_delete.go index e1e49d2acdc..ee25dd07eb9 100644 --- a/pkg/database/ent/meta_delete.go +++ b/pkg/database/ent/meta_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (md *MetaDelete) Where(ps ...predicate.Meta) *MetaDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MetaDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MetaDelete) ExecX(ctx context.Context) int { } func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(meta.Table, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MetaDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MetaDeleteOne struct { md *MetaDelete } +// Where appends a list predicates to the MetaDelete builder. +func (mdo *MetaDeleteOne) Where(ps ...predicate.Meta) *MetaDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MetaDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MetaDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/pkg/database/ent/meta_query.go b/pkg/database/ent/meta_query.go index d6fd4f3d522..87d91d09e0e 100644 --- a/pkg/database/ent/meta_query.go +++ b/pkg/database/ent/meta_query.go @@ -18,11 +18,9 @@ import ( // MetaQuery is the builder for querying Meta entities. type MetaQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []meta.OrderOption + inters []Interceptor predicates []predicate.Meta withOwner *AlertQuery // intermediate query (i.e. traversal path). @@ -36,34 +34,34 @@ func (mq *MetaQuery) Where(ps ...predicate.Meta) *MetaQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MetaQuery) Limit(limit int) *MetaQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MetaQuery) Offset(offset int) *MetaQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MetaQuery) Unique(unique bool) *MetaQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. -func (mq *MetaQuery) Order(o ...OrderFunc) *MetaQuery { +// Order specifies how the records should be ordered. +func (mq *MetaQuery) Order(o ...meta.OrderOption) *MetaQuery { mq.order = append(mq.order, o...) return mq } // QueryOwner chains the current query on the "owner" edge. func (mq *MetaQuery) QueryOwner() *AlertQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -86,7 +84,7 @@ func (mq *MetaQuery) QueryOwner() *AlertQuery { // First returns the first Meta entity from the query. // Returns a *NotFoundError when no Meta was found. func (mq *MetaQuery) First(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -109,7 +107,7 @@ func (mq *MetaQuery) FirstX(ctx context.Context) *Meta { // Returns a *NotFoundError when no Meta ID was found. func (mq *MetaQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -132,7 +130,7 @@ func (mq *MetaQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Meta entity is found. // Returns a *NotFoundError when no Meta entities are found. func (mq *MetaQuery) Only(ctx context.Context) (*Meta, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (mq *MetaQuery) OnlyX(ctx context.Context) *Meta { // Returns a *NotFoundError when no entities are found. func (mq *MetaQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -185,10 +183,12 @@ func (mq *MetaQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MetaSlice. func (mq *MetaQuery) All(ctx context.Context) ([]*Meta, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Meta, *MetaQuery]() + return withInterceptors[[]*Meta](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -201,9 +201,12 @@ func (mq *MetaQuery) AllX(ctx context.Context) []*Meta { } // IDs executes the query and returns a list of Meta IDs. -func (mq *MetaQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MetaQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(meta.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -220,10 +223,11 @@ func (mq *MetaQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mq *MetaQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MetaQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -237,10 +241,15 @@ func (mq *MetaQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MetaQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -260,22 +269,21 @@ func (mq *MetaQuery) Clone() *MetaQuery { } return &MetaQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, - order: append([]OrderFunc{}, mq.order...), + ctx: mq.ctx.Clone(), + order: append([]meta.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Meta{}, mq.predicates...), withOwner: mq.withOwner.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithOwner tells the query-builder to eager-load the nodes that are connected to // the "owner" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { - query := &AlertQuery{config: mq.config} + query := (&AlertClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -298,16 +306,11 @@ func (mq *MetaQuery) WithOwner(opts ...func(*AlertQuery)) *MetaQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { - grbuild := &MetaGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetaGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = meta.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -324,15 +327,30 @@ func (mq *MetaQuery) GroupBy(field string, fields ...string) *MetaGroupBy { // Select(meta.FieldCreatedAt). // Scan(ctx, &v) func (mq *MetaQuery) Select(fields ...string) *MetaSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MetaSelect{MetaQuery: mq} - selbuild.label = meta.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetaSelect{MetaQuery: mq} + sbuild.label = meta.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetaSelect configured with the given aggregations. +func (mq *MetaQuery) Aggregate(fns ...AggregateFunc) *MetaSelect { + return mq.Select().Aggregate(fns...) } func (mq *MetaQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !meta.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -392,6 +410,9 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(alert.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -411,41 +432,22 @@ func (mq *MetaQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []* func (mq *MetaQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MetaQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := mq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, meta.FieldID) for i := range fields { @@ -453,6 +455,9 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if mq.withOwner != nil { + _spec.Node.AddColumnOnce(meta.FieldAlertMetas) + } } if ps := mq.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -461,10 +466,10 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -480,7 +485,7 @@ func (mq *MetaQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(meta.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = meta.Columns } @@ -489,7 +494,7 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -498,12 +503,12 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -511,13 +516,8 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector { // MetaGroupBy is the group-by builder for Meta entities. type MetaGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MetaQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -526,74 +526,77 @@ func (mgb *MetaGroupBy) Aggregate(fns ...AggregateFunc) *MetaGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (mgb *MetaGroupBy) Scan(ctx context.Context, v any) error { - query, err := mgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MetaGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range mgb.fields { - if !meta.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MetaGroupBy) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MetaGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MetaSelect is the builder for selecting fields of Meta entities. type MetaSelect struct { *MetaQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetaSelect) Aggregate(fns ...AggregateFunc) *MetaSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. func (ms *MetaSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MetaQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MetaQuery, *MetaSelect](ctx, ms.MetaQuery, ms, ms.inters, v) } -func (ms *MetaSelect) sqlScan(ctx context.Context, v any) error { +func (ms *MetaSelect) sqlScan(ctx context.Context, root *MetaQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/pkg/database/ent/meta_update.go b/pkg/database/ent/meta_update.go index 67a198dddfa..bdf622eb6c3 100644 --- a/pkg/database/ent/meta_update.go +++ b/pkg/database/ent/meta_update.go @@ -29,42 +29,12 @@ func (mu *MetaUpdate) Where(ps ...predicate.Meta) *MetaUpdate { return mu } -// SetCreatedAt sets the "created_at" field. -func (mu *MetaUpdate) SetCreatedAt(t time.Time) *MetaUpdate { - mu.mutation.SetCreatedAt(t) - return mu -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (mu *MetaUpdate) ClearCreatedAt() *MetaUpdate { - mu.mutation.ClearCreatedAt() - return mu -} - // SetUpdatedAt sets the "updated_at" field. func (mu *MetaUpdate) SetUpdatedAt(t time.Time) *MetaUpdate { mu.mutation.SetUpdatedAt(t) return mu } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (mu *MetaUpdate) ClearUpdatedAt() *MetaUpdate { - mu.mutation.ClearUpdatedAt() - return mu -} - -// SetKey sets the "key" field. -func (mu *MetaUpdate) SetKey(s string) *MetaUpdate { - mu.mutation.SetKey(s) - return mu -} - -// SetValue sets the "value" field. -func (mu *MetaUpdate) SetValue(s string) *MetaUpdate { - mu.mutation.SetValue(s) - return mu -} - // SetAlertMetas sets the "alert_metas" field. func (mu *MetaUpdate) SetAlertMetas(i int) *MetaUpdate { mu.mutation.SetAlertMetas(i) @@ -117,41 +87,8 @@ func (mu *MetaUpdate) ClearOwner() *MetaUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MetaUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) mu.defaults() - if len(mu.hooks) == 0 { - if err = mu.check(); err != nil { - return 0, err - } - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mu.check(); err != nil { - return 0, err - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -178,37 +115,14 @@ func (mu *MetaUpdate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (mu *MetaUpdate) defaults() { - if _, ok := mu.mutation.CreatedAt(); !ok && !mu.mutation.CreatedAtCleared() { - v := meta.UpdateDefaultCreatedAt() - mu.mutation.SetCreatedAt(v) - } - if _, ok := mu.mutation.UpdatedAt(); !ok && !mu.mutation.UpdatedAtCleared() { + if _, ok := mu.mutation.UpdatedAt(); !ok { v := meta.UpdateDefaultUpdatedAt() mu.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (mu *MetaUpdate) check() error { - if v, ok := mu.mutation.Value(); ok { - if err := meta.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Meta.value": %w`, err)} - } - } - return nil -} - func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -216,45 +130,8 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } - if value, ok := mu.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - } - if mu.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) - } if value, ok := mu.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - } - if mu.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) - } - if value, ok := mu.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) - } - if value, ok := mu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if mu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -264,10 +141,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -280,10 +154,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -299,6 +170,7 @@ func (mu *MetaUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -310,42 +182,12 @@ type MetaUpdateOne struct { mutation *MetaMutation } -// SetCreatedAt sets the "created_at" field. -func (muo *MetaUpdateOne) SetCreatedAt(t time.Time) *MetaUpdateOne { - muo.mutation.SetCreatedAt(t) - return muo -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (muo *MetaUpdateOne) ClearCreatedAt() *MetaUpdateOne { - muo.mutation.ClearCreatedAt() - return muo -} - // SetUpdatedAt sets the "updated_at" field. func (muo *MetaUpdateOne) SetUpdatedAt(t time.Time) *MetaUpdateOne { muo.mutation.SetUpdatedAt(t) return muo } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (muo *MetaUpdateOne) ClearUpdatedAt() *MetaUpdateOne { - muo.mutation.ClearUpdatedAt() - return muo -} - -// SetKey sets the "key" field. -func (muo *MetaUpdateOne) SetKey(s string) *MetaUpdateOne { - muo.mutation.SetKey(s) - return muo -} - -// SetValue sets the "value" field. -func (muo *MetaUpdateOne) SetValue(s string) *MetaUpdateOne { - muo.mutation.SetValue(s) - return muo -} - // SetAlertMetas sets the "alert_metas" field. func (muo *MetaUpdateOne) SetAlertMetas(i int) *MetaUpdateOne { muo.mutation.SetAlertMetas(i) @@ -396,6 +238,12 @@ func (muo *MetaUpdateOne) ClearOwner() *MetaUpdateOne { return muo } +// Where appends a list predicates to the MetaUpdate builder. +func (muo *MetaUpdateOne) Where(ps ...predicate.Meta) *MetaUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne { @@ -405,47 +253,8 @@ func (muo *MetaUpdateOne) Select(field string, fields ...string) *MetaUpdateOne // Save executes the query and returns the updated Meta entity. func (muo *MetaUpdateOne) Save(ctx context.Context) (*Meta, error) { - var ( - err error - node *Meta - ) muo.defaults() - if len(muo.hooks) == 0 { - if err = muo.check(); err != nil { - return nil, err - } - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MetaMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = muo.check(); err != nil { - return nil, err - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Meta) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MetaMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -472,37 +281,14 @@ func (muo *MetaUpdateOne) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (muo *MetaUpdateOne) defaults() { - if _, ok := muo.mutation.CreatedAt(); !ok && !muo.mutation.CreatedAtCleared() { - v := meta.UpdateDefaultCreatedAt() - muo.mutation.SetCreatedAt(v) - } - if _, ok := muo.mutation.UpdatedAt(); !ok && !muo.mutation.UpdatedAtCleared() { + if _, ok := muo.mutation.UpdatedAt(); !ok { v := meta.UpdateDefaultUpdatedAt() muo.mutation.SetUpdatedAt(v) } } -// check runs all checks and user-defined validators on the builder. -func (muo *MetaUpdateOne) check() error { - if v, ok := muo.mutation.Value(); ok { - if err := meta.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Meta.value": %w`, err)} - } - } - return nil -} - func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: meta.Table, - Columns: meta.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: meta.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(meta.Table, meta.Columns, sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt)) id, ok := muo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Meta.id" for update`)} @@ -527,45 +313,8 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } } } - if value, ok := muo.mutation.CreatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldCreatedAt, - }) - } - if muo.mutation.CreatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldCreatedAt, - }) - } if value, ok := muo.mutation.UpdatedAt(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: meta.FieldUpdatedAt, - }) - } - if muo.mutation.UpdatedAtCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Column: meta.FieldUpdatedAt, - }) - } - if value, ok := muo.mutation.Key(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldKey, - }) - } - if value, ok := muo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: meta.FieldValue, - }) + _spec.SetField(meta.FieldUpdatedAt, field.TypeTime, value) } if muo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ @@ -575,10 +324,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -591,10 +337,7 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) Columns: []string{meta.OwnerColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: alert.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -613,5 +356,6 @@ func (muo *MetaUpdateOne) sqlSave(ctx context.Context) (_node *Meta, err error) } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/pkg/database/ent/metric.go b/pkg/database/ent/metric.go new file mode 100644 index 00000000000..47f3b4df4e5 --- /dev/null +++ b/pkg/database/ent/metric.go @@ -0,0 +1,154 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +// Metric is the model entity for the Metric schema. +type Metric struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Type of the metrics source: LP=logprocessor, RC=remediation + GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` + // Source of the metrics: machine id, bouncer name... + // It must come from the auth middleware. + GeneratedBy string `json:"generated_by,omitempty"` + // When the metrics are received by LAPI + ReceivedAt time.Time `json:"received_at,omitempty"` + // When the metrics are sent to the console + PushedAt *time.Time `json:"pushed_at,omitempty"` + // The actual metrics (item0) + Payload string `json:"payload,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Metric) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case metric.FieldID: + values[i] = new(sql.NullInt64) + case metric.FieldGeneratedType, metric.FieldGeneratedBy, metric.FieldPayload: + values[i] = new(sql.NullString) + case metric.FieldReceivedAt, metric.FieldPushedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Metric fields. +func (m *Metric) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case metric.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + m.ID = int(value.Int64) + case metric.FieldGeneratedType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field generated_type", values[i]) + } else if value.Valid { + m.GeneratedType = metric.GeneratedType(value.String) + } + case metric.FieldGeneratedBy: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field generated_by", values[i]) + } else if value.Valid { + m.GeneratedBy = value.String + } + case metric.FieldReceivedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field received_at", values[i]) + } else if value.Valid { + m.ReceivedAt = value.Time + } + case metric.FieldPushedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field pushed_at", values[i]) + } else if value.Valid { + m.PushedAt = new(time.Time) + *m.PushedAt = value.Time + } + case metric.FieldPayload: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field payload", values[i]) + } else if value.Valid { + m.Payload = value.String + } + default: + m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Metric. +// This includes values selected through modifiers, order, etc. +func (m *Metric) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + +// Update returns a builder for updating this Metric. +// Note that you need to call Metric.Unwrap() before calling this method if this Metric +// was returned from a transaction, and the transaction was committed or rolled back. +func (m *Metric) Update() *MetricUpdateOne { + return NewMetricClient(m.config).UpdateOne(m) +} + +// Unwrap unwraps the Metric entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (m *Metric) Unwrap() *Metric { + _tx, ok := m.config.driver.(*txDriver) + if !ok { + panic("ent: Metric is not a transactional entity") + } + m.config.driver = _tx.drv + return m +} + +// String implements the fmt.Stringer. +func (m *Metric) String() string { + var builder strings.Builder + builder.WriteString("Metric(") + builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) + builder.WriteString("generated_type=") + builder.WriteString(fmt.Sprintf("%v", m.GeneratedType)) + builder.WriteString(", ") + builder.WriteString("generated_by=") + builder.WriteString(m.GeneratedBy) + builder.WriteString(", ") + builder.WriteString("received_at=") + builder.WriteString(m.ReceivedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := m.PushedAt; v != nil { + builder.WriteString("pushed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("payload=") + builder.WriteString(m.Payload) + builder.WriteByte(')') + return builder.String() +} + +// Metrics is a parsable slice of Metric. +type Metrics []*Metric diff --git a/pkg/database/ent/metric/metric.go b/pkg/database/ent/metric/metric.go new file mode 100644 index 00000000000..78e88982220 --- /dev/null +++ b/pkg/database/ent/metric/metric.go @@ -0,0 +1,104 @@ +// Code generated by ent, DO NOT EDIT. + +package metric + +import ( + "fmt" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the metric type in the database. + Label = "metric" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldGeneratedType holds the string denoting the generated_type field in the database. + FieldGeneratedType = "generated_type" + // FieldGeneratedBy holds the string denoting the generated_by field in the database. + FieldGeneratedBy = "generated_by" + // FieldReceivedAt holds the string denoting the received_at field in the database. + FieldReceivedAt = "received_at" + // FieldPushedAt holds the string denoting the pushed_at field in the database. + FieldPushedAt = "pushed_at" + // FieldPayload holds the string denoting the payload field in the database. + FieldPayload = "payload" + // Table holds the table name of the metric in the database. + Table = "metrics" +) + +// Columns holds all SQL columns for metric fields. +var Columns = []string{ + FieldID, + FieldGeneratedType, + FieldGeneratedBy, + FieldReceivedAt, + FieldPushedAt, + FieldPayload, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// GeneratedType defines the type for the "generated_type" enum field. +type GeneratedType string + +// GeneratedType values. +const ( + GeneratedTypeLP GeneratedType = "LP" + GeneratedTypeRC GeneratedType = "RC" +) + +func (gt GeneratedType) String() string { + return string(gt) +} + +// GeneratedTypeValidator is a validator for the "generated_type" field enum values. It is called by the builders before save. +func GeneratedTypeValidator(gt GeneratedType) error { + switch gt { + case GeneratedTypeLP, GeneratedTypeRC: + return nil + default: + return fmt.Errorf("metric: invalid enum value for generated_type field: %q", gt) + } +} + +// OrderOption defines the ordering options for the Metric queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByGeneratedType orders the results by the generated_type field. +func ByGeneratedType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGeneratedType, opts...).ToFunc() +} + +// ByGeneratedBy orders the results by the generated_by field. +func ByGeneratedBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGeneratedBy, opts...).ToFunc() +} + +// ByReceivedAt orders the results by the received_at field. +func ByReceivedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldReceivedAt, opts...).ToFunc() +} + +// ByPushedAt orders the results by the pushed_at field. +func ByPushedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPushedAt, opts...).ToFunc() +} + +// ByPayload orders the results by the payload field. +func ByPayload(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPayload, opts...).ToFunc() +} diff --git a/pkg/database/ent/metric/where.go b/pkg/database/ent/metric/where.go new file mode 100644 index 00000000000..72bd9d93cd7 --- /dev/null +++ b/pkg/database/ent/metric/where.go @@ -0,0 +1,330 @@ +// Code generated by ent, DO NOT EDIT. + +package metric + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldID, id)) +} + +// GeneratedBy applies equality check predicate on the "generated_by" field. It's identical to GeneratedByEQ. +func GeneratedBy(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedBy, v)) +} + +// ReceivedAt applies equality check predicate on the "received_at" field. It's identical to ReceivedAtEQ. +func ReceivedAt(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldReceivedAt, v)) +} + +// PushedAt applies equality check predicate on the "pushed_at" field. It's identical to PushedAtEQ. +func PushedAt(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPushedAt, v)) +} + +// Payload applies equality check predicate on the "payload" field. It's identical to PayloadEQ. +func Payload(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPayload, v)) +} + +// GeneratedTypeEQ applies the EQ predicate on the "generated_type" field. +func GeneratedTypeEQ(v GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedType, v)) +} + +// GeneratedTypeNEQ applies the NEQ predicate on the "generated_type" field. +func GeneratedTypeNEQ(v GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldGeneratedType, v)) +} + +// GeneratedTypeIn applies the In predicate on the "generated_type" field. +func GeneratedTypeIn(vs ...GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldGeneratedType, vs...)) +} + +// GeneratedTypeNotIn applies the NotIn predicate on the "generated_type" field. +func GeneratedTypeNotIn(vs ...GeneratedType) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldGeneratedType, vs...)) +} + +// GeneratedByEQ applies the EQ predicate on the "generated_by" field. +func GeneratedByEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldGeneratedBy, v)) +} + +// GeneratedByNEQ applies the NEQ predicate on the "generated_by" field. +func GeneratedByNEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldGeneratedBy, v)) +} + +// GeneratedByIn applies the In predicate on the "generated_by" field. +func GeneratedByIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldGeneratedBy, vs...)) +} + +// GeneratedByNotIn applies the NotIn predicate on the "generated_by" field. +func GeneratedByNotIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldGeneratedBy, vs...)) +} + +// GeneratedByGT applies the GT predicate on the "generated_by" field. +func GeneratedByGT(v string) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldGeneratedBy, v)) +} + +// GeneratedByGTE applies the GTE predicate on the "generated_by" field. +func GeneratedByGTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldGeneratedBy, v)) +} + +// GeneratedByLT applies the LT predicate on the "generated_by" field. +func GeneratedByLT(v string) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldGeneratedBy, v)) +} + +// GeneratedByLTE applies the LTE predicate on the "generated_by" field. +func GeneratedByLTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldGeneratedBy, v)) +} + +// GeneratedByContains applies the Contains predicate on the "generated_by" field. +func GeneratedByContains(v string) predicate.Metric { + return predicate.Metric(sql.FieldContains(FieldGeneratedBy, v)) +} + +// GeneratedByHasPrefix applies the HasPrefix predicate on the "generated_by" field. +func GeneratedByHasPrefix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasPrefix(FieldGeneratedBy, v)) +} + +// GeneratedByHasSuffix applies the HasSuffix predicate on the "generated_by" field. +func GeneratedByHasSuffix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasSuffix(FieldGeneratedBy, v)) +} + +// GeneratedByEqualFold applies the EqualFold predicate on the "generated_by" field. +func GeneratedByEqualFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldEqualFold(FieldGeneratedBy, v)) +} + +// GeneratedByContainsFold applies the ContainsFold predicate on the "generated_by" field. +func GeneratedByContainsFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldContainsFold(FieldGeneratedBy, v)) +} + +// ReceivedAtEQ applies the EQ predicate on the "received_at" field. +func ReceivedAtEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldReceivedAt, v)) +} + +// ReceivedAtNEQ applies the NEQ predicate on the "received_at" field. +func ReceivedAtNEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldReceivedAt, v)) +} + +// ReceivedAtIn applies the In predicate on the "received_at" field. +func ReceivedAtIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldReceivedAt, vs...)) +} + +// ReceivedAtNotIn applies the NotIn predicate on the "received_at" field. +func ReceivedAtNotIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldReceivedAt, vs...)) +} + +// ReceivedAtGT applies the GT predicate on the "received_at" field. +func ReceivedAtGT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldReceivedAt, v)) +} + +// ReceivedAtGTE applies the GTE predicate on the "received_at" field. +func ReceivedAtGTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldReceivedAt, v)) +} + +// ReceivedAtLT applies the LT predicate on the "received_at" field. +func ReceivedAtLT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldReceivedAt, v)) +} + +// ReceivedAtLTE applies the LTE predicate on the "received_at" field. +func ReceivedAtLTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldReceivedAt, v)) +} + +// PushedAtEQ applies the EQ predicate on the "pushed_at" field. +func PushedAtEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPushedAt, v)) +} + +// PushedAtNEQ applies the NEQ predicate on the "pushed_at" field. +func PushedAtNEQ(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldPushedAt, v)) +} + +// PushedAtIn applies the In predicate on the "pushed_at" field. +func PushedAtIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldPushedAt, vs...)) +} + +// PushedAtNotIn applies the NotIn predicate on the "pushed_at" field. +func PushedAtNotIn(vs ...time.Time) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldPushedAt, vs...)) +} + +// PushedAtGT applies the GT predicate on the "pushed_at" field. +func PushedAtGT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldPushedAt, v)) +} + +// PushedAtGTE applies the GTE predicate on the "pushed_at" field. +func PushedAtGTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldPushedAt, v)) +} + +// PushedAtLT applies the LT predicate on the "pushed_at" field. +func PushedAtLT(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldPushedAt, v)) +} + +// PushedAtLTE applies the LTE predicate on the "pushed_at" field. +func PushedAtLTE(v time.Time) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldPushedAt, v)) +} + +// PushedAtIsNil applies the IsNil predicate on the "pushed_at" field. +func PushedAtIsNil() predicate.Metric { + return predicate.Metric(sql.FieldIsNull(FieldPushedAt)) +} + +// PushedAtNotNil applies the NotNil predicate on the "pushed_at" field. +func PushedAtNotNil() predicate.Metric { + return predicate.Metric(sql.FieldNotNull(FieldPushedAt)) +} + +// PayloadEQ applies the EQ predicate on the "payload" field. +func PayloadEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldEQ(FieldPayload, v)) +} + +// PayloadNEQ applies the NEQ predicate on the "payload" field. +func PayloadNEQ(v string) predicate.Metric { + return predicate.Metric(sql.FieldNEQ(FieldPayload, v)) +} + +// PayloadIn applies the In predicate on the "payload" field. +func PayloadIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldIn(FieldPayload, vs...)) +} + +// PayloadNotIn applies the NotIn predicate on the "payload" field. +func PayloadNotIn(vs ...string) predicate.Metric { + return predicate.Metric(sql.FieldNotIn(FieldPayload, vs...)) +} + +// PayloadGT applies the GT predicate on the "payload" field. +func PayloadGT(v string) predicate.Metric { + return predicate.Metric(sql.FieldGT(FieldPayload, v)) +} + +// PayloadGTE applies the GTE predicate on the "payload" field. +func PayloadGTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldGTE(FieldPayload, v)) +} + +// PayloadLT applies the LT predicate on the "payload" field. +func PayloadLT(v string) predicate.Metric { + return predicate.Metric(sql.FieldLT(FieldPayload, v)) +} + +// PayloadLTE applies the LTE predicate on the "payload" field. +func PayloadLTE(v string) predicate.Metric { + return predicate.Metric(sql.FieldLTE(FieldPayload, v)) +} + +// PayloadContains applies the Contains predicate on the "payload" field. +func PayloadContains(v string) predicate.Metric { + return predicate.Metric(sql.FieldContains(FieldPayload, v)) +} + +// PayloadHasPrefix applies the HasPrefix predicate on the "payload" field. +func PayloadHasPrefix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasPrefix(FieldPayload, v)) +} + +// PayloadHasSuffix applies the HasSuffix predicate on the "payload" field. +func PayloadHasSuffix(v string) predicate.Metric { + return predicate.Metric(sql.FieldHasSuffix(FieldPayload, v)) +} + +// PayloadEqualFold applies the EqualFold predicate on the "payload" field. +func PayloadEqualFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldEqualFold(FieldPayload, v)) +} + +// PayloadContainsFold applies the ContainsFold predicate on the "payload" field. +func PayloadContainsFold(v string) predicate.Metric { + return predicate.Metric(sql.FieldContainsFold(FieldPayload, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Metric) predicate.Metric { + return predicate.Metric(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Metric) predicate.Metric { + return predicate.Metric(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Metric) predicate.Metric { + return predicate.Metric(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/metric_create.go b/pkg/database/ent/metric_create.go new file mode 100644 index 00000000000..973cddd41d0 --- /dev/null +++ b/pkg/database/ent/metric_create.go @@ -0,0 +1,246 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +// MetricCreate is the builder for creating a Metric entity. +type MetricCreate struct { + config + mutation *MetricMutation + hooks []Hook +} + +// SetGeneratedType sets the "generated_type" field. +func (mc *MetricCreate) SetGeneratedType(mt metric.GeneratedType) *MetricCreate { + mc.mutation.SetGeneratedType(mt) + return mc +} + +// SetGeneratedBy sets the "generated_by" field. +func (mc *MetricCreate) SetGeneratedBy(s string) *MetricCreate { + mc.mutation.SetGeneratedBy(s) + return mc +} + +// SetReceivedAt sets the "received_at" field. +func (mc *MetricCreate) SetReceivedAt(t time.Time) *MetricCreate { + mc.mutation.SetReceivedAt(t) + return mc +} + +// SetPushedAt sets the "pushed_at" field. +func (mc *MetricCreate) SetPushedAt(t time.Time) *MetricCreate { + mc.mutation.SetPushedAt(t) + return mc +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. +func (mc *MetricCreate) SetNillablePushedAt(t *time.Time) *MetricCreate { + if t != nil { + mc.SetPushedAt(*t) + } + return mc +} + +// SetPayload sets the "payload" field. +func (mc *MetricCreate) SetPayload(s string) *MetricCreate { + mc.mutation.SetPayload(s) + return mc +} + +// Mutation returns the MetricMutation object of the builder. +func (mc *MetricCreate) Mutation() *MetricMutation { + return mc.mutation +} + +// Save creates the Metric in the database. +func (mc *MetricCreate) Save(ctx context.Context) (*Metric, error) { + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (mc *MetricCreate) SaveX(ctx context.Context) *Metric { + v, err := mc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mc *MetricCreate) Exec(ctx context.Context) error { + _, err := mc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mc *MetricCreate) ExecX(ctx context.Context) { + if err := mc.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mc *MetricCreate) check() error { + if _, ok := mc.mutation.GeneratedType(); !ok { + return &ValidationError{Name: "generated_type", err: errors.New(`ent: missing required field "Metric.generated_type"`)} + } + if v, ok := mc.mutation.GeneratedType(); ok { + if err := metric.GeneratedTypeValidator(v); err != nil { + return &ValidationError{Name: "generated_type", err: fmt.Errorf(`ent: validator failed for field "Metric.generated_type": %w`, err)} + } + } + if _, ok := mc.mutation.GeneratedBy(); !ok { + return &ValidationError{Name: "generated_by", err: errors.New(`ent: missing required field "Metric.generated_by"`)} + } + if _, ok := mc.mutation.ReceivedAt(); !ok { + return &ValidationError{Name: "received_at", err: errors.New(`ent: missing required field "Metric.received_at"`)} + } + if _, ok := mc.mutation.Payload(); !ok { + return &ValidationError{Name: "payload", err: errors.New(`ent: missing required field "Metric.payload"`)} + } + return nil +} + +func (mc *MetricCreate) sqlSave(ctx context.Context) (*Metric, error) { + if err := mc.check(); err != nil { + return nil, err + } + _node, _spec := mc.createSpec() + if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true + return _node, nil +} + +func (mc *MetricCreate) createSpec() (*Metric, *sqlgraph.CreateSpec) { + var ( + _node = &Metric{config: mc.config} + _spec = sqlgraph.NewCreateSpec(metric.Table, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + ) + if value, ok := mc.mutation.GeneratedType(); ok { + _spec.SetField(metric.FieldGeneratedType, field.TypeEnum, value) + _node.GeneratedType = value + } + if value, ok := mc.mutation.GeneratedBy(); ok { + _spec.SetField(metric.FieldGeneratedBy, field.TypeString, value) + _node.GeneratedBy = value + } + if value, ok := mc.mutation.ReceivedAt(); ok { + _spec.SetField(metric.FieldReceivedAt, field.TypeTime, value) + _node.ReceivedAt = value + } + if value, ok := mc.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + _node.PushedAt = &value + } + if value, ok := mc.mutation.Payload(); ok { + _spec.SetField(metric.FieldPayload, field.TypeString, value) + _node.Payload = value + } + return _node, _spec +} + +// MetricCreateBulk is the builder for creating many Metric entities in bulk. +type MetricCreateBulk struct { + config + err error + builders []*MetricCreate +} + +// Save creates the Metric entities in the database. +func (mcb *MetricCreateBulk) Save(ctx context.Context) ([]*Metric, error) { + if mcb.err != nil { + return nil, mcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) + nodes := make([]*Metric, len(mcb.builders)) + mutators := make([]Mutator, len(mcb.builders)) + for i := range mcb.builders { + func(i int, root context.Context) { + builder := mcb.builders[i] + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*MetricMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, mcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, mcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (mcb *MetricCreateBulk) SaveX(ctx context.Context) []*Metric { + v, err := mcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mcb *MetricCreateBulk) Exec(ctx context.Context) error { + _, err := mcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mcb *MetricCreateBulk) ExecX(ctx context.Context) { + if err := mcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/metric_delete.go b/pkg/database/ent/metric_delete.go new file mode 100644 index 00000000000..d6606680a6a --- /dev/null +++ b/pkg/database/ent/metric_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricDelete is the builder for deleting a Metric entity. +type MetricDelete struct { + config + hooks []Hook + mutation *MetricMutation +} + +// Where appends a list predicates to the MetricDelete builder. +func (md *MetricDelete) Where(ps ...predicate.Metric) *MetricDelete { + md.mutation.Where(ps...) + return md +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (md *MetricDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (md *MetricDelete) ExecX(ctx context.Context) int { + n, err := md.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (md *MetricDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(metric.Table, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + if ps := md.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, md.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + md.mutation.done = true + return affected, err +} + +// MetricDeleteOne is the builder for deleting a single Metric entity. +type MetricDeleteOne struct { + md *MetricDelete +} + +// Where appends a list predicates to the MetricDelete builder. +func (mdo *MetricDeleteOne) Where(ps ...predicate.Metric) *MetricDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + +// Exec executes the deletion query. +func (mdo *MetricDeleteOne) Exec(ctx context.Context) error { + n, err := mdo.md.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{metric.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (mdo *MetricDeleteOne) ExecX(ctx context.Context) { + if err := mdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/metric_query.go b/pkg/database/ent/metric_query.go new file mode 100644 index 00000000000..6e1c6f08b4a --- /dev/null +++ b/pkg/database/ent/metric_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricQuery is the builder for querying Metric entities. +type MetricQuery struct { + config + ctx *QueryContext + order []metric.OrderOption + inters []Interceptor + predicates []predicate.Metric + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the MetricQuery builder. +func (mq *MetricQuery) Where(ps ...predicate.Metric) *MetricQuery { + mq.predicates = append(mq.predicates, ps...) + return mq +} + +// Limit the number of records to be returned by this query. +func (mq *MetricQuery) Limit(limit int) *MetricQuery { + mq.ctx.Limit = &limit + return mq +} + +// Offset to start from. +func (mq *MetricQuery) Offset(offset int) *MetricQuery { + mq.ctx.Offset = &offset + return mq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (mq *MetricQuery) Unique(unique bool) *MetricQuery { + mq.ctx.Unique = &unique + return mq +} + +// Order specifies how the records should be ordered. +func (mq *MetricQuery) Order(o ...metric.OrderOption) *MetricQuery { + mq.order = append(mq.order, o...) + return mq +} + +// First returns the first Metric entity from the query. +// Returns a *NotFoundError when no Metric was found. +func (mq *MetricQuery) First(ctx context.Context) (*Metric, error) { + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{metric.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (mq *MetricQuery) FirstX(ctx context.Context) *Metric { + node, err := mq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Metric ID from the query. +// Returns a *NotFoundError when no Metric ID was found. +func (mq *MetricQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{metric.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (mq *MetricQuery) FirstIDX(ctx context.Context) int { + id, err := mq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Metric entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Metric entity is found. +// Returns a *NotFoundError when no Metric entities are found. +func (mq *MetricQuery) Only(ctx context.Context) (*Metric, error) { + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{metric.Label} + default: + return nil, &NotSingularError{metric.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (mq *MetricQuery) OnlyX(ctx context.Context) *Metric { + node, err := mq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Metric ID in the query. +// Returns a *NotSingularError when more than one Metric ID is found. +// Returns a *NotFoundError when no entities are found. +func (mq *MetricQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{metric.Label} + default: + err = &NotSingularError{metric.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (mq *MetricQuery) OnlyIDX(ctx context.Context) int { + id, err := mq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Metrics. +func (mq *MetricQuery) All(ctx context.Context) ([]*Metric, error) { + ctx = setContextOp(ctx, mq.ctx, "All") + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Metric, *MetricQuery]() + return withInterceptors[[]*Metric](ctx, mq, qr, mq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (mq *MetricQuery) AllX(ctx context.Context) []*Metric { + nodes, err := mq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Metric IDs. +func (mq *MetricQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(metric.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (mq *MetricQuery) IDsX(ctx context.Context) []int { + ids, err := mq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (mq *MetricQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") + if err := mq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, mq, querierCount[*MetricQuery](), mq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (mq *MetricQuery) CountX(ctx context.Context) int { + count, err := mq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (mq *MetricQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (mq *MetricQuery) ExistX(ctx context.Context) bool { + exist, err := mq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the MetricQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (mq *MetricQuery) Clone() *MetricQuery { + if mq == nil { + return nil + } + return &MetricQuery{ + config: mq.config, + ctx: mq.ctx.Clone(), + order: append([]metric.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), + predicates: append([]predicate.Metric{}, mq.predicates...), + // clone intermediate query. + sql: mq.sql.Clone(), + path: mq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Metric.Query(). +// GroupBy(metric.FieldGeneratedType). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (mq *MetricQuery) GroupBy(field string, fields ...string) *MetricGroupBy { + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetricGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields + grbuild.label = metric.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// GeneratedType metric.GeneratedType `json:"generated_type,omitempty"` +// } +// +// client.Metric.Query(). +// Select(metric.FieldGeneratedType). +// Scan(ctx, &v) +func (mq *MetricQuery) Select(fields ...string) *MetricSelect { + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetricSelect{MetricQuery: mq} + sbuild.label = metric.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetricSelect configured with the given aggregations. +func (mq *MetricQuery) Aggregate(fns ...AggregateFunc) *MetricSelect { + return mq.Select().Aggregate(fns...) +} + +func (mq *MetricQuery) prepareQuery(ctx context.Context) error { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { + if !metric.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if mq.path != nil { + prev, err := mq.path(ctx) + if err != nil { + return err + } + mq.sql = prev + } + return nil +} + +func (mq *MetricQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Metric, error) { + var ( + nodes = []*Metric{} + _spec = mq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Metric).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Metric{config: mq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (mq *MetricQuery) sqlCount(ctx context.Context) (int, error) { + _spec := mq.querySpec() + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, mq.driver, _spec) +} + +func (mq *MetricQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true + } + if fields := mq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metric.FieldID) + for i := range fields { + if fields[i] != metric.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := mq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := mq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := mq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := mq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (mq *MetricQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(mq.driver.Dialect()) + t1 := builder.Table(metric.Table) + columns := mq.ctx.Fields + if len(columns) == 0 { + columns = metric.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if mq.sql != nil { + selector = mq.sql + selector.Select(selector.Columns(columns...)...) + } + if mq.ctx.Unique != nil && *mq.ctx.Unique { + selector.Distinct() + } + for _, p := range mq.predicates { + p(selector) + } + for _, p := range mq.order { + p(selector) + } + if offset := mq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := mq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// MetricGroupBy is the group-by builder for Metric entities. +type MetricGroupBy struct { + selector + build *MetricQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (mgb *MetricGroupBy) Aggregate(fns ...AggregateFunc) *MetricGroupBy { + mgb.fns = append(mgb.fns, fns...) + return mgb +} + +// Scan applies the selector query and scans the result into the given value. +func (mgb *MetricGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetricQuery, *MetricGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) +} + +func (mgb *MetricGroupBy) sqlScan(ctx context.Context, root *MetricQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*mgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// MetricSelect is the builder for selecting fields of Metric entities. +type MetricSelect struct { + *MetricQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetricSelect) Aggregate(fns ...AggregateFunc) *MetricSelect { + ms.fns = append(ms.fns, fns...) + return ms +} + +// Scan applies the selector query and scans the result into the given value. +func (ms *MetricSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") + if err := ms.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetricQuery, *MetricSelect](ctx, ms.MetricQuery, ms, ms.inters, v) +} + +func (ms *MetricSelect) sqlScan(ctx context.Context, root *MetricQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ms.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/metric_update.go b/pkg/database/ent/metric_update.go new file mode 100644 index 00000000000..4da33dd6ce9 --- /dev/null +++ b/pkg/database/ent/metric_update.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// MetricUpdate is the builder for updating Metric entities. +type MetricUpdate struct { + config + hooks []Hook + mutation *MetricMutation +} + +// Where appends a list predicates to the MetricUpdate builder. +func (mu *MetricUpdate) Where(ps ...predicate.Metric) *MetricUpdate { + mu.mutation.Where(ps...) + return mu +} + +// SetPushedAt sets the "pushed_at" field. +func (mu *MetricUpdate) SetPushedAt(t time.Time) *MetricUpdate { + mu.mutation.SetPushedAt(t) + return mu +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. +func (mu *MetricUpdate) SetNillablePushedAt(t *time.Time) *MetricUpdate { + if t != nil { + mu.SetPushedAt(*t) + } + return mu +} + +// ClearPushedAt clears the value of the "pushed_at" field. +func (mu *MetricUpdate) ClearPushedAt() *MetricUpdate { + mu.mutation.ClearPushedAt() + return mu +} + +// Mutation returns the MetricMutation object of the builder. +func (mu *MetricUpdate) Mutation() *MetricMutation { + return mu.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (mu *MetricUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mu *MetricUpdate) SaveX(ctx context.Context) int { + affected, err := mu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (mu *MetricUpdate) Exec(ctx context.Context) error { + _, err := mu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mu *MetricUpdate) ExecX(ctx context.Context) { + if err := mu.Exec(ctx); err != nil { + panic(err) + } +} + +func (mu *MetricUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + if ps := mu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mu.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + } + if mu.mutation.PushedAtCleared() { + _spec.ClearField(metric.FieldPushedAt, field.TypeTime) + } + if n, err = sqlgraph.UpdateNodes(ctx, mu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metric.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + mu.mutation.done = true + return n, nil +} + +// MetricUpdateOne is the builder for updating a single Metric entity. +type MetricUpdateOne struct { + config + fields []string + hooks []Hook + mutation *MetricMutation +} + +// SetPushedAt sets the "pushed_at" field. +func (muo *MetricUpdateOne) SetPushedAt(t time.Time) *MetricUpdateOne { + muo.mutation.SetPushedAt(t) + return muo +} + +// SetNillablePushedAt sets the "pushed_at" field if the given value is not nil. +func (muo *MetricUpdateOne) SetNillablePushedAt(t *time.Time) *MetricUpdateOne { + if t != nil { + muo.SetPushedAt(*t) + } + return muo +} + +// ClearPushedAt clears the value of the "pushed_at" field. +func (muo *MetricUpdateOne) ClearPushedAt() *MetricUpdateOne { + muo.mutation.ClearPushedAt() + return muo +} + +// Mutation returns the MetricMutation object of the builder. +func (muo *MetricUpdateOne) Mutation() *MetricMutation { + return muo.mutation +} + +// Where appends a list predicates to the MetricUpdate builder. +func (muo *MetricUpdateOne) Where(ps ...predicate.Metric) *MetricUpdateOne { + muo.mutation.Where(ps...) + return muo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (muo *MetricUpdateOne) Select(field string, fields ...string) *MetricUpdateOne { + muo.fields = append([]string{field}, fields...) + return muo +} + +// Save executes the query and returns the updated Metric entity. +func (muo *MetricUpdateOne) Save(ctx context.Context) (*Metric, error) { + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (muo *MetricUpdateOne) SaveX(ctx context.Context) *Metric { + node, err := muo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (muo *MetricUpdateOne) Exec(ctx context.Context) error { + _, err := muo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (muo *MetricUpdateOne) ExecX(ctx context.Context) { + if err := muo.Exec(ctx); err != nil { + panic(err) + } +} + +func (muo *MetricUpdateOne) sqlSave(ctx context.Context) (_node *Metric, err error) { + _spec := sqlgraph.NewUpdateSpec(metric.Table, metric.Columns, sqlgraph.NewFieldSpec(metric.FieldID, field.TypeInt)) + id, ok := muo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Metric.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := muo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metric.FieldID) + for _, f := range fields { + if !metric.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != metric.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := muo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := muo.mutation.PushedAt(); ok { + _spec.SetField(metric.FieldPushedAt, field.TypeTime, value) + } + if muo.mutation.PushedAtCleared() { + _spec.ClearField(metric.FieldPushedAt, field.TypeTime) + } + _node = &Metric{config: muo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, muo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metric.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + muo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 375fd4e784a..986f5bc8c67 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -11,8 +11,8 @@ var ( // AlertsColumns holds the columns for the "alerts" table. AlertsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "scenario", Type: field.TypeString}, {Name: "bucket_id", Type: field.TypeString, Nullable: true, Default: ""}, {Name: "message", Type: field.TypeString, Nullable: true, Default: ""}, @@ -34,6 +34,7 @@ var ( {Name: "scenario_hash", Type: field.TypeString, Nullable: true}, {Name: "simulated", Type: field.TypeBool, Default: false}, {Name: "uuid", Type: field.TypeString, Nullable: true}, + {Name: "remediation", Type: field.TypeBool, Nullable: true}, {Name: "machine_alerts", Type: field.TypeInt, Nullable: true}, } // AlertsTable holds the schema information for the "alerts" table. @@ -44,7 +45,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "alerts_machines_alerts", - Columns: []*schema.Column{AlertsColumns[24]}, + Columns: []*schema.Column{AlertsColumns[25]}, RefColumns: []*schema.Column{MachinesColumns[0]}, OnDelete: schema.SetNull, }, @@ -60,17 +61,19 @@ var ( // BouncersColumns holds the columns for the "bouncers" table. BouncersColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "api_key", Type: field.TypeString}, {Name: "revoked", Type: field.TypeBool}, {Name: "ip_address", Type: field.TypeString, Nullable: true, Default: ""}, {Name: "type", Type: field.TypeString, Nullable: true}, {Name: "version", Type: field.TypeString, Nullable: true}, - {Name: "until", Type: field.TypeTime, Nullable: true}, - {Name: "last_pull", Type: field.TypeTime}, + {Name: "last_pull", Type: field.TypeTime, Nullable: true}, {Name: "auth_type", Type: field.TypeString, Default: "api-key"}, + {Name: "osname", Type: field.TypeString, Nullable: true}, + {Name: "osversion", Type: field.TypeString, Nullable: true}, + {Name: "featureflags", Type: field.TypeString, Nullable: true}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ @@ -81,8 +84,8 @@ var ( // ConfigItemsColumns holds the columns for the "config_items" table. ConfigItemsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "value", Type: field.TypeString}, } @@ -95,8 +98,8 @@ var ( // DecisionsColumns holds the columns for the "decisions" table. DecisionsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, {Name: "scenario", Type: field.TypeString}, {Name: "type", Type: field.TypeString}, @@ -151,8 +154,8 @@ var ( // EventsColumns holds the columns for the "events" table. EventsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "time", Type: field.TypeTime}, {Name: "serialized", Type: field.TypeString, Size: 8191}, {Name: "alert_events", Type: field.TypeInt, Nullable: true}, @@ -178,11 +181,23 @@ var ( }, }, } + // LocksColumns holds the columns for the "locks" table. + LocksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "name", Type: field.TypeString, Unique: true}, + {Name: "created_at", Type: field.TypeTime}, + } + // LocksTable holds the schema information for the "locks" table. + LocksTable = &schema.Table{ + Name: "locks", + Columns: LocksColumns, + PrimaryKey: []*schema.Column{LocksColumns[0]}, + } // MachinesColumns holds the columns for the "machines" table. MachinesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "last_push", Type: field.TypeTime, Nullable: true}, {Name: "last_heartbeat", Type: field.TypeTime, Nullable: true}, {Name: "machine_id", Type: field.TypeString, Unique: true}, @@ -191,8 +206,12 @@ var ( {Name: "scenarios", Type: field.TypeString, Nullable: true, Size: 100000}, {Name: "version", Type: field.TypeString, Nullable: true}, {Name: "is_validated", Type: field.TypeBool, Default: false}, - {Name: "status", Type: field.TypeString, Nullable: true}, {Name: "auth_type", Type: field.TypeString, Default: "password"}, + {Name: "osname", Type: field.TypeString, Nullable: true}, + {Name: "osversion", Type: field.TypeString, Nullable: true}, + {Name: "featureflags", Type: field.TypeString, Nullable: true}, + {Name: "hubstate", Type: field.TypeJSON, Nullable: true}, + {Name: "datasources", Type: field.TypeJSON, Nullable: true}, } // MachinesTable holds the schema information for the "machines" table. MachinesTable = &schema.Table{ @@ -203,8 +222,8 @@ var ( // MetaColumns holds the columns for the "meta" table. MetaColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "created_at", Type: field.TypeTime, Nullable: true}, - {Name: "updated_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, {Name: "key", Type: field.TypeString}, {Name: "value", Type: field.TypeString, Size: 4095}, {Name: "alert_metas", Type: field.TypeInt, Nullable: true}, @@ -230,6 +249,21 @@ var ( }, }, } + // MetricsColumns holds the columns for the "metrics" table. + MetricsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "generated_type", Type: field.TypeEnum, Enums: []string{"LP", "RC"}}, + {Name: "generated_by", Type: field.TypeString}, + {Name: "received_at", Type: field.TypeTime}, + {Name: "pushed_at", Type: field.TypeTime, Nullable: true}, + {Name: "payload", Type: field.TypeString, Size: 2147483647}, + } + // MetricsTable holds the schema information for the "metrics" table. + MetricsTable = &schema.Table{ + Name: "metrics", + Columns: MetricsColumns, + PrimaryKey: []*schema.Column{MetricsColumns[0]}, + } // Tables holds all the tables in the schema. Tables = []*schema.Table{ AlertsTable, @@ -237,8 +271,10 @@ var ( ConfigItemsTable, DecisionsTable, EventsTable, + LocksTable, MachinesTable, MetaTable, + MetricsTable, } ) diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 907c1ef015e..5c6596f3db4 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -9,16 +9,19 @@ import ( "sync" "time" + "entgo.io/ent" + "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" - - "entgo.io/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" ) const ( @@ -35,8 +38,10 @@ const ( TypeConfigItem = "ConfigItem" TypeDecision = "Decision" TypeEvent = "Event" + TypeLock = "Lock" TypeMachine = "Machine" TypeMeta = "Meta" + TypeMetric = "Metric" ) // AlertMutation represents an operation that mutates the Alert nodes in the graph. @@ -72,6 +77,7 @@ type AlertMutation struct { scenarioHash *string simulated *bool uuid *string + remediation *bool clearedFields map[string]struct{} owner *int clearedowner bool @@ -204,7 +210,7 @@ func (m *AlertMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Alert entity. // If the Alert object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -218,22 +224,9 @@ func (m *AlertMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *AlertMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[alert.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *AlertMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[alert.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *AlertMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, alert.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -253,7 +246,7 @@ func (m *AlertMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Alert entity. // If the Alert object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -267,22 +260,9 @@ func (m *AlertMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *AlertMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[alert.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *AlertMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[alert.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *AlertMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, alert.FieldUpdatedAt) } // SetScenario sets the "scenario" field. @@ -1372,6 +1352,55 @@ func (m *AlertMutation) ResetUUID() { delete(m.clearedFields, alert.FieldUUID) } +// SetRemediation sets the "remediation" field. +func (m *AlertMutation) SetRemediation(b bool) { + m.remediation = &b +} + +// Remediation returns the value of the "remediation" field in the mutation. +func (m *AlertMutation) Remediation() (r bool, exists bool) { + v := m.remediation + if v == nil { + return + } + return *v, true +} + +// OldRemediation returns the old "remediation" field's value of the Alert entity. +// If the Alert object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AlertMutation) OldRemediation(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRemediation is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRemediation requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRemediation: %w", err) + } + return oldValue.Remediation, nil +} + +// ClearRemediation clears the value of the "remediation" field. +func (m *AlertMutation) ClearRemediation() { + m.remediation = nil + m.clearedFields[alert.FieldRemediation] = struct{}{} +} + +// RemediationCleared returns if the "remediation" field was cleared in this mutation. +func (m *AlertMutation) RemediationCleared() bool { + _, ok := m.clearedFields[alert.FieldRemediation] + return ok +} + +// ResetRemediation resets all changes to the "remediation" field. +func (m *AlertMutation) ResetRemediation() { + m.remediation = nil + delete(m.clearedFields, alert.FieldRemediation) +} + // SetOwnerID sets the "owner" edge to the Machine entity by id. func (m *AlertMutation) SetOwnerID(id int) { m.owner = &id @@ -1578,11 +1607,26 @@ func (m *AlertMutation) Where(ps ...predicate.Alert) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the AlertMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AlertMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Alert, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *AlertMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *AlertMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Alert). func (m *AlertMutation) Type() string { return m.typ @@ -1592,7 +1636,7 @@ func (m *AlertMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AlertMutation) Fields() []string { - fields := make([]string, 0, 23) + fields := make([]string, 0, 24) if m.created_at != nil { fields = append(fields, alert.FieldCreatedAt) } @@ -1662,6 +1706,9 @@ func (m *AlertMutation) Fields() []string { if m.uuid != nil { fields = append(fields, alert.FieldUUID) } + if m.remediation != nil { + fields = append(fields, alert.FieldRemediation) + } return fields } @@ -1716,6 +1763,8 @@ func (m *AlertMutation) Field(name string) (ent.Value, bool) { return m.Simulated() case alert.FieldUUID: return m.UUID() + case alert.FieldRemediation: + return m.Remediation() } return nil, false } @@ -1771,6 +1820,8 @@ func (m *AlertMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSimulated(ctx) case alert.FieldUUID: return m.OldUUID(ctx) + case alert.FieldRemediation: + return m.OldRemediation(ctx) } return nil, fmt.Errorf("unknown Alert field %s", name) } @@ -1941,6 +1992,13 @@ func (m *AlertMutation) SetField(name string, value ent.Value) error { } m.SetUUID(v) return nil + case alert.FieldRemediation: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRemediation(v) + return nil } return fmt.Errorf("unknown Alert field %s", name) } @@ -2022,12 +2080,6 @@ func (m *AlertMutation) AddField(name string, value ent.Value) error { // mutation. func (m *AlertMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(alert.FieldCreatedAt) { - fields = append(fields, alert.FieldCreatedAt) - } - if m.FieldCleared(alert.FieldUpdatedAt) { - fields = append(fields, alert.FieldUpdatedAt) - } if m.FieldCleared(alert.FieldBucketId) { fields = append(fields, alert.FieldBucketId) } @@ -2085,6 +2137,9 @@ func (m *AlertMutation) ClearedFields() []string { if m.FieldCleared(alert.FieldUUID) { fields = append(fields, alert.FieldUUID) } + if m.FieldCleared(alert.FieldRemediation) { + fields = append(fields, alert.FieldRemediation) + } return fields } @@ -2099,12 +2154,6 @@ func (m *AlertMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *AlertMutation) ClearField(name string) error { switch name { - case alert.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case alert.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case alert.FieldBucketId: m.ClearBucketId() return nil @@ -2162,6 +2211,9 @@ func (m *AlertMutation) ClearField(name string) error { case alert.FieldUUID: m.ClearUUID() return nil + case alert.FieldRemediation: + m.ClearRemediation() + return nil } return fmt.Errorf("unknown Alert nullable field %s", name) } @@ -2239,6 +2291,9 @@ func (m *AlertMutation) ResetField(name string) error { case alert.FieldUUID: m.ResetUUID() return nil + case alert.FieldRemediation: + m.ResetRemediation() + return nil } return fmt.Errorf("unknown Alert field %s", name) } @@ -2411,9 +2466,11 @@ type BouncerMutation struct { ip_address *string _type *string version *string - until *time.Time last_pull *time.Time auth_type *string + osname *string + osversion *string + featureflags *string clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -2535,7 +2592,7 @@ func (m *BouncerMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -2549,22 +2606,9 @@ func (m *BouncerMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err e return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *BouncerMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[bouncer.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *BouncerMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[bouncer.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *BouncerMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, bouncer.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -2584,7 +2628,7 @@ func (m *BouncerMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -2598,22 +2642,9 @@ func (m *BouncerMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err e return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *BouncerMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[bouncer.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *BouncerMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[bouncer.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *BouncerMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, bouncer.FieldUpdatedAt) } // SetName sets the "name" field. @@ -2871,55 +2902,6 @@ func (m *BouncerMutation) ResetVersion() { delete(m.clearedFields, bouncer.FieldVersion) } -// SetUntil sets the "until" field. -func (m *BouncerMutation) SetUntil(t time.Time) { - m.until = &t -} - -// Until returns the value of the "until" field in the mutation. -func (m *BouncerMutation) Until() (r time.Time, exists bool) { - v := m.until - if v == nil { - return - } - return *v, true -} - -// OldUntil returns the old "until" field's value of the Bouncer entity. -// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldUntil(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUntil is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUntil requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUntil: %w", err) - } - return oldValue.Until, nil -} - -// ClearUntil clears the value of the "until" field. -func (m *BouncerMutation) ClearUntil() { - m.until = nil - m.clearedFields[bouncer.FieldUntil] = struct{}{} -} - -// UntilCleared returns if the "until" field was cleared in this mutation. -func (m *BouncerMutation) UntilCleared() bool { - _, ok := m.clearedFields[bouncer.FieldUntil] - return ok -} - -// ResetUntil resets all changes to the "until" field. -func (m *BouncerMutation) ResetUntil() { - m.until = nil - delete(m.clearedFields, bouncer.FieldUntil) -} - // SetLastPull sets the "last_pull" field. func (m *BouncerMutation) SetLastPull(t time.Time) { m.last_pull = &t @@ -2937,7 +2919,7 @@ func (m *BouncerMutation) LastPull() (r time.Time, exists bool) { // OldLastPull returns the old "last_pull" field's value of the Bouncer entity. // If the Bouncer object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *BouncerMutation) OldLastPull(ctx context.Context) (v time.Time, err error) { +func (m *BouncerMutation) OldLastPull(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldLastPull is only allowed on UpdateOne operations") } @@ -2951,9 +2933,22 @@ func (m *BouncerMutation) OldLastPull(ctx context.Context) (v time.Time, err err return oldValue.LastPull, nil } +// ClearLastPull clears the value of the "last_pull" field. +func (m *BouncerMutation) ClearLastPull() { + m.last_pull = nil + m.clearedFields[bouncer.FieldLastPull] = struct{}{} +} + +// LastPullCleared returns if the "last_pull" field was cleared in this mutation. +func (m *BouncerMutation) LastPullCleared() bool { + _, ok := m.clearedFields[bouncer.FieldLastPull] + return ok +} + // ResetLastPull resets all changes to the "last_pull" field. func (m *BouncerMutation) ResetLastPull() { m.last_pull = nil + delete(m.clearedFields, bouncer.FieldLastPull) } // SetAuthType sets the "auth_type" field. @@ -2992,16 +2987,178 @@ func (m *BouncerMutation) ResetAuthType() { m.auth_type = nil } +// SetOsname sets the "osname" field. +func (m *BouncerMutation) SetOsname(s string) { + m.osname = &s +} + +// Osname returns the value of the "osname" field in the mutation. +func (m *BouncerMutation) Osname() (r string, exists bool) { + v := m.osname + if v == nil { + return + } + return *v, true +} + +// OldOsname returns the old "osname" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldOsname(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsname is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsname requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsname: %w", err) + } + return oldValue.Osname, nil +} + +// ClearOsname clears the value of the "osname" field. +func (m *BouncerMutation) ClearOsname() { + m.osname = nil + m.clearedFields[bouncer.FieldOsname] = struct{}{} +} + +// OsnameCleared returns if the "osname" field was cleared in this mutation. +func (m *BouncerMutation) OsnameCleared() bool { + _, ok := m.clearedFields[bouncer.FieldOsname] + return ok +} + +// ResetOsname resets all changes to the "osname" field. +func (m *BouncerMutation) ResetOsname() { + m.osname = nil + delete(m.clearedFields, bouncer.FieldOsname) +} + +// SetOsversion sets the "osversion" field. +func (m *BouncerMutation) SetOsversion(s string) { + m.osversion = &s +} + +// Osversion returns the value of the "osversion" field in the mutation. +func (m *BouncerMutation) Osversion() (r string, exists bool) { + v := m.osversion + if v == nil { + return + } + return *v, true +} + +// OldOsversion returns the old "osversion" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldOsversion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsversion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsversion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsversion: %w", err) + } + return oldValue.Osversion, nil +} + +// ClearOsversion clears the value of the "osversion" field. +func (m *BouncerMutation) ClearOsversion() { + m.osversion = nil + m.clearedFields[bouncer.FieldOsversion] = struct{}{} +} + +// OsversionCleared returns if the "osversion" field was cleared in this mutation. +func (m *BouncerMutation) OsversionCleared() bool { + _, ok := m.clearedFields[bouncer.FieldOsversion] + return ok +} + +// ResetOsversion resets all changes to the "osversion" field. +func (m *BouncerMutation) ResetOsversion() { + m.osversion = nil + delete(m.clearedFields, bouncer.FieldOsversion) +} + +// SetFeatureflags sets the "featureflags" field. +func (m *BouncerMutation) SetFeatureflags(s string) { + m.featureflags = &s +} + +// Featureflags returns the value of the "featureflags" field in the mutation. +func (m *BouncerMutation) Featureflags() (r string, exists bool) { + v := m.featureflags + if v == nil { + return + } + return *v, true +} + +// OldFeatureflags returns the old "featureflags" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldFeatureflags(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeatureflags is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeatureflags requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeatureflags: %w", err) + } + return oldValue.Featureflags, nil +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (m *BouncerMutation) ClearFeatureflags() { + m.featureflags = nil + m.clearedFields[bouncer.FieldFeatureflags] = struct{}{} +} + +// FeatureflagsCleared returns if the "featureflags" field was cleared in this mutation. +func (m *BouncerMutation) FeatureflagsCleared() bool { + _, ok := m.clearedFields[bouncer.FieldFeatureflags] + return ok +} + +// ResetFeatureflags resets all changes to the "featureflags" field. +func (m *BouncerMutation) ResetFeatureflags() { + m.featureflags = nil + delete(m.clearedFields, bouncer.FieldFeatureflags) +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the BouncerMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *BouncerMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Bouncer, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *BouncerMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *BouncerMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Bouncer). func (m *BouncerMutation) Type() string { return m.typ @@ -3011,7 +3168,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 13) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -3036,15 +3193,21 @@ func (m *BouncerMutation) Fields() []string { if m.version != nil { fields = append(fields, bouncer.FieldVersion) } - if m.until != nil { - fields = append(fields, bouncer.FieldUntil) - } if m.last_pull != nil { fields = append(fields, bouncer.FieldLastPull) } if m.auth_type != nil { fields = append(fields, bouncer.FieldAuthType) } + if m.osname != nil { + fields = append(fields, bouncer.FieldOsname) + } + if m.osversion != nil { + fields = append(fields, bouncer.FieldOsversion) + } + if m.featureflags != nil { + fields = append(fields, bouncer.FieldFeatureflags) + } return fields } @@ -3069,12 +3232,16 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.GetType() case bouncer.FieldVersion: return m.Version() - case bouncer.FieldUntil: - return m.Until() case bouncer.FieldLastPull: return m.LastPull() case bouncer.FieldAuthType: return m.AuthType() + case bouncer.FieldOsname: + return m.Osname() + case bouncer.FieldOsversion: + return m.Osversion() + case bouncer.FieldFeatureflags: + return m.Featureflags() } return nil, false } @@ -3100,12 +3267,16 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldType(ctx) case bouncer.FieldVersion: return m.OldVersion(ctx) - case bouncer.FieldUntil: - return m.OldUntil(ctx) case bouncer.FieldLastPull: return m.OldLastPull(ctx) case bouncer.FieldAuthType: return m.OldAuthType(ctx) + case bouncer.FieldOsname: + return m.OldOsname(ctx) + case bouncer.FieldOsversion: + return m.OldOsversion(ctx) + case bouncer.FieldFeatureflags: + return m.OldFeatureflags(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3171,13 +3342,6 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetVersion(v) return nil - case bouncer.FieldUntil: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUntil(v) - return nil case bouncer.FieldLastPull: v, ok := value.(time.Time) if !ok { @@ -3192,6 +3356,27 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetAuthType(v) return nil + case bouncer.FieldOsname: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsname(v) + return nil + case bouncer.FieldOsversion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsversion(v) + return nil + case bouncer.FieldFeatureflags: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeatureflags(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3222,12 +3407,6 @@ func (m *BouncerMutation) AddField(name string, value ent.Value) error { // mutation. func (m *BouncerMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(bouncer.FieldCreatedAt) { - fields = append(fields, bouncer.FieldCreatedAt) - } - if m.FieldCleared(bouncer.FieldUpdatedAt) { - fields = append(fields, bouncer.FieldUpdatedAt) - } if m.FieldCleared(bouncer.FieldIPAddress) { fields = append(fields, bouncer.FieldIPAddress) } @@ -3237,8 +3416,17 @@ func (m *BouncerMutation) ClearedFields() []string { if m.FieldCleared(bouncer.FieldVersion) { fields = append(fields, bouncer.FieldVersion) } - if m.FieldCleared(bouncer.FieldUntil) { - fields = append(fields, bouncer.FieldUntil) + if m.FieldCleared(bouncer.FieldLastPull) { + fields = append(fields, bouncer.FieldLastPull) + } + if m.FieldCleared(bouncer.FieldOsname) { + fields = append(fields, bouncer.FieldOsname) + } + if m.FieldCleared(bouncer.FieldOsversion) { + fields = append(fields, bouncer.FieldOsversion) + } + if m.FieldCleared(bouncer.FieldFeatureflags) { + fields = append(fields, bouncer.FieldFeatureflags) } return fields } @@ -3254,12 +3442,6 @@ func (m *BouncerMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *BouncerMutation) ClearField(name string) error { switch name { - case bouncer.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case bouncer.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case bouncer.FieldIPAddress: m.ClearIPAddress() return nil @@ -3269,8 +3451,17 @@ func (m *BouncerMutation) ClearField(name string) error { case bouncer.FieldVersion: m.ClearVersion() return nil - case bouncer.FieldUntil: - m.ClearUntil() + case bouncer.FieldLastPull: + m.ClearLastPull() + return nil + case bouncer.FieldOsname: + m.ClearOsname() + return nil + case bouncer.FieldOsversion: + m.ClearOsversion() + return nil + case bouncer.FieldFeatureflags: + m.ClearFeatureflags() return nil } return fmt.Errorf("unknown Bouncer nullable field %s", name) @@ -3304,15 +3495,21 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldVersion: m.ResetVersion() return nil - case bouncer.FieldUntil: - m.ResetUntil() - return nil case bouncer.FieldLastPull: m.ResetLastPull() return nil case bouncer.FieldAuthType: m.ResetAuthType() return nil + case bouncer.FieldOsname: + m.ResetOsname() + return nil + case bouncer.FieldOsversion: + m.ResetOsversion() + return nil + case bouncer.FieldFeatureflags: + m.ResetFeatureflags() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3496,7 +3693,7 @@ func (m *ConfigItemMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the ConfigItem entity. // If the ConfigItem object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -3510,22 +3707,9 @@ func (m *ConfigItemMutation) OldCreatedAt(ctx context.Context) (v *time.Time, er return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *ConfigItemMutation) ClearCreatedAt() { +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ConfigItemMutation) ResetCreatedAt() { m.created_at = nil - m.clearedFields[configitem.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *ConfigItemMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[configitem.FieldCreatedAt] - return ok -} - -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *ConfigItemMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, configitem.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -3545,7 +3729,7 @@ func (m *ConfigItemMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the ConfigItem entity. // If the ConfigItem object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -3559,22 +3743,9 @@ func (m *ConfigItemMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, er return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *ConfigItemMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[configitem.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *ConfigItemMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[configitem.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *ConfigItemMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, configitem.FieldUpdatedAt) } // SetName sets the "name" field. @@ -3654,11 +3825,26 @@ func (m *ConfigItemMutation) Where(ps ...predicate.ConfigItem) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the ConfigItemMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ConfigItemMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ConfigItem, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *ConfigItemMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *ConfigItemMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (ConfigItem). func (m *ConfigItemMutation) Type() string { return m.typ @@ -3780,14 +3966,7 @@ func (m *ConfigItemMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *ConfigItemMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(configitem.FieldCreatedAt) { - fields = append(fields, configitem.FieldCreatedAt) - } - if m.FieldCleared(configitem.FieldUpdatedAt) { - fields = append(fields, configitem.FieldUpdatedAt) - } - return fields + return nil } // FieldCleared returns a boolean indicating if a field with the given name was @@ -3800,14 +3979,6 @@ func (m *ConfigItemMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *ConfigItemMutation) ClearField(name string) error { - switch name { - case configitem.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case configitem.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - } return fmt.Errorf("unknown ConfigItem nullable field %s", name) } @@ -4028,7 +4199,7 @@ func (m *DecisionMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Decision entity. // If the Decision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -4042,22 +4213,9 @@ func (m *DecisionMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *DecisionMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[decision.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *DecisionMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[decision.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *DecisionMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, decision.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -4077,7 +4235,7 @@ func (m *DecisionMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Decision entity. // If the Decision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -4091,22 +4249,9 @@ func (m *DecisionMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *DecisionMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[decision.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *DecisionMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[decision.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *DecisionMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, decision.FieldUpdatedAt) } // SetUntil sets the "until" field. @@ -4830,6 +4975,7 @@ func (m *DecisionMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *DecisionMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[decision.FieldAlertDecisions] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -4866,11 +5012,26 @@ func (m *DecisionMutation) Where(ps ...predicate.Decision) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Decision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *DecisionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DecisionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Decision). func (m *DecisionMutation) Type() string { return m.typ @@ -5224,12 +5385,6 @@ func (m *DecisionMutation) AddField(name string, value ent.Value) error { // mutation. func (m *DecisionMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(decision.FieldCreatedAt) { - fields = append(fields, decision.FieldCreatedAt) - } - if m.FieldCleared(decision.FieldUpdatedAt) { - fields = append(fields, decision.FieldUpdatedAt) - } if m.FieldCleared(decision.FieldUntil) { fields = append(fields, decision.FieldUntil) } @@ -5268,12 +5423,6 @@ func (m *DecisionMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *DecisionMutation) ClearField(name string) error { switch name { - case decision.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case decision.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case decision.FieldUntil: m.ClearUntil() return nil @@ -5565,7 +5714,7 @@ func (m *EventMutation) CreatedAt() (r time.Time, exists bool) { // OldCreatedAt returns the old "created_at" field's value of the Event entity. // If the Event object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *EventMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *EventMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -5579,22 +5728,9 @@ func (m *EventMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err err return oldValue.CreatedAt, nil } -// ClearCreatedAt clears the value of the "created_at" field. -func (m *EventMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[event.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *EventMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[event.FieldCreatedAt] - return ok -} - // ResetCreatedAt resets all changes to the "created_at" field. func (m *EventMutation) ResetCreatedAt() { m.created_at = nil - delete(m.clearedFields, event.FieldCreatedAt) } // SetUpdatedAt sets the "updated_at" field. @@ -5614,7 +5750,7 @@ func (m *EventMutation) UpdatedAt() (r time.Time, exists bool) { // OldUpdatedAt returns the old "updated_at" field's value of the Event entity. // If the Event object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -5628,22 +5764,9 @@ func (m *EventMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err err return oldValue.UpdatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *EventMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[event.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *EventMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[event.FieldUpdatedAt] - return ok -} - // ResetUpdatedAt resets all changes to the "updated_at" field. func (m *EventMutation) ResetUpdatedAt() { m.updated_at = nil - delete(m.clearedFields, event.FieldUpdatedAt) } // SetTime sets the "time" field. @@ -5775,6 +5898,7 @@ func (m *EventMutation) SetOwnerID(id int) { // ClearOwner clears the "owner" edge to the Alert entity. func (m *EventMutation) ClearOwner() { m.clearedowner = true + m.clearedFields[event.FieldAlertEvents] = struct{}{} } // OwnerCleared reports if the "owner" edge to the Alert entity was cleared. @@ -5811,11 +5935,26 @@ func (m *EventMutation) Where(ps ...predicate.Event) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the EventMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *EventMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Event, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *EventMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *EventMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Event). func (m *EventMutation) Type() string { return m.typ @@ -5955,12 +6094,6 @@ func (m *EventMutation) AddField(name string, value ent.Value) error { // mutation. func (m *EventMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(event.FieldCreatedAt) { - fields = append(fields, event.FieldCreatedAt) - } - if m.FieldCleared(event.FieldUpdatedAt) { - fields = append(fields, event.FieldUpdatedAt) - } if m.FieldCleared(event.FieldAlertEvents) { fields = append(fields, event.FieldAlertEvents) } @@ -5978,12 +6111,6 @@ func (m *EventMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *EventMutation) ClearField(name string) error { switch name { - case event.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case event.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil case event.FieldAlertEvents: m.ClearAlertEvents() return nil @@ -6088,44 +6215,31 @@ func (m *EventMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Event edge %s", name) } -// MachineMutation represents an operation that mutates the Machine nodes in the graph. -type MachineMutation struct { +// LockMutation represents an operation that mutates the Lock nodes in the graph. +type LockMutation struct { config - op Op - typ string - id *int - created_at *time.Time - updated_at *time.Time - last_push *time.Time - last_heartbeat *time.Time - machineId *string - password *string - ipAddress *string - scenarios *string - version *string - isValidated *bool - status *string - auth_type *string - clearedFields map[string]struct{} - alerts map[int]struct{} - removedalerts map[int]struct{} - clearedalerts bool - done bool - oldValue func(context.Context) (*Machine, error) - predicates []predicate.Machine + op Op + typ string + id *int + name *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Lock, error) + predicates []predicate.Lock } -var _ ent.Mutation = (*MachineMutation)(nil) +var _ ent.Mutation = (*LockMutation)(nil) -// machineOption allows management of the mutation configuration using functional options. -type machineOption func(*MachineMutation) +// lockOption allows management of the mutation configuration using functional options. +type lockOption func(*LockMutation) -// newMachineMutation creates new mutation for the Machine entity. -func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation { - m := &MachineMutation{ +// newLockMutation creates new mutation for the Lock entity. +func newLockMutation(c config, op Op, opts ...lockOption) *LockMutation { + m := &LockMutation{ config: c, op: op, - typ: TypeMachine, + typ: TypeLock, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -6134,20 +6248,20 @@ func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation return m } -// withMachineID sets the ID field of the mutation. -func withMachineID(id int) machineOption { - return func(m *MachineMutation) { +// withLockID sets the ID field of the mutation. +func withLockID(id int) lockOption { + return func(m *LockMutation) { var ( err error once sync.Once - value *Machine + value *Lock ) - m.oldValue = func(ctx context.Context) (*Machine, error) { + m.oldValue = func(ctx context.Context) (*Lock, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Machine.Get(ctx, id) + value, err = m.Client().Lock.Get(ctx, id) } }) return value, err @@ -6156,10 +6270,10 @@ func withMachineID(id int) machineOption { } } -// withMachine sets the old Machine of the mutation. -func withMachine(node *Machine) machineOption { - return func(m *MachineMutation) { - m.oldValue = func(context.Context) (*Machine, error) { +// withLock sets the old Lock of the mutation. +func withLock(node *Lock) lockOption { + return func(m *LockMutation) { + m.oldValue = func(context.Context) (*Lock, error) { return node, nil } m.id = &node.ID @@ -6168,7 +6282,7 @@ func withMachine(node *Machine) machineOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m MachineMutation) Client() *Client { +func (m LockMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -6176,7 +6290,7 @@ func (m MachineMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m MachineMutation) Tx() (*Tx, error) { +func (m LockMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -6187,7 +6301,7 @@ func (m MachineMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *MachineMutation) ID() (id int, exists bool) { +func (m *LockMutation) ID() (id int, exists bool) { if m.id == nil { return } @@ -6198,7 +6312,7 @@ func (m *MachineMutation) ID() (id int, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) { +func (m *LockMutation) IDs(ctx context.Context) ([]int, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -6207,228 +6321,599 @@ func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Machine.Query().Where(m.predicates...).IDs(ctx) + return m.Client().Lock.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetCreatedAt sets the "created_at" field. -func (m *MachineMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetName sets the "name" field. +func (m *LockMutation) SetName(s string) { + m.name = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *MachineMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// Name returns the value of the "name" field in the mutation. +func (m *LockMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *LockMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.CreatedAt, nil -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (m *MachineMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[machine.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *MachineMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[machine.FieldCreatedAt] - return ok + return oldValue.Name, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *MachineMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, machine.FieldCreatedAt) +// ResetName resets all changes to the "name" field. +func (m *LockMutation) ResetName() { + m.name = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *MachineMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetCreatedAt sets the "created_at" field. +func (m *LockMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *MachineMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *LockMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the Lock entity. +// If the Lock object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *LockMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.CreatedAt, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *MachineMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[machine.FieldUpdatedAt] = struct{}{} +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *LockMutation) ResetCreatedAt() { + m.created_at = nil } -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *MachineMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[machine.FieldUpdatedAt] - return ok +// Where appends a list predicates to the LockMutation builder. +func (m *LockMutation) Where(ps ...predicate.Lock) { + m.predicates = append(m.predicates, ps...) } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *MachineMutation) ResetUpdatedAt() { - m.updated_at = nil - delete(m.clearedFields, machine.FieldUpdatedAt) +// WhereP appends storage-level predicates to the LockMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *LockMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Lock, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SetLastPush sets the "last_push" field. -func (m *MachineMutation) SetLastPush(t time.Time) { - m.last_push = &t +// Op returns the operation name. +func (m *LockMutation) Op() Op { + return m.op } -// LastPush returns the value of the "last_push" field in the mutation. -func (m *MachineMutation) LastPush() (r time.Time, exists bool) { - v := m.last_push - if v == nil { - return - } - return *v, true +// SetOp allows setting the mutation operation. +func (m *LockMutation) SetOp(op Op) { + m.op = op } -// OldLastPush returns the old "last_push" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldLastPush(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLastPush is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLastPush requires an ID field in the mutation") +// Type returns the node type of this mutation (Lock). +func (m *LockMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *LockMutation) Fields() []string { + fields := make([]string, 0, 2) + if m.name != nil { + fields = append(fields, lock.FieldName) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLastPush: %w", err) + if m.created_at != nil { + fields = append(fields, lock.FieldCreatedAt) } - return oldValue.LastPush, nil + return fields } -// ClearLastPush clears the value of the "last_push" field. -func (m *MachineMutation) ClearLastPush() { - m.last_push = nil - m.clearedFields[machine.FieldLastPush] = struct{}{} +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *LockMutation) Field(name string) (ent.Value, bool) { + switch name { + case lock.FieldName: + return m.Name() + case lock.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false } -// LastPushCleared returns if the "last_push" field was cleared in this mutation. -func (m *MachineMutation) LastPushCleared() bool { - _, ok := m.clearedFields[machine.FieldLastPush] - return ok +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *LockMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case lock.FieldName: + return m.OldName(ctx) + case lock.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown Lock field %s", name) } -// ResetLastPush resets all changes to the "last_push" field. -func (m *MachineMutation) ResetLastPush() { - m.last_push = nil - delete(m.clearedFields, machine.FieldLastPush) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) SetField(name string, value ent.Value) error { + switch name { + case lock.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case lock.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown Lock field %s", name) } -// SetLastHeartbeat sets the "last_heartbeat" field. -func (m *MachineMutation) SetLastHeartbeat(t time.Time) { - m.last_heartbeat = &t +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *LockMutation) AddedFields() []string { + return nil } -// LastHeartbeat returns the value of the "last_heartbeat" field in the mutation. -func (m *MachineMutation) LastHeartbeat() (r time.Time, exists bool) { - v := m.last_heartbeat - if v == nil { - return +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *LockMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *LockMutation) AddField(name string, value ent.Value) error { + switch name { } - return *v, true + return fmt.Errorf("unknown Lock numeric field %s", name) } -// OldLastHeartbeat returns the old "last_heartbeat" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldLastHeartbeat(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLastHeartbeat is only allowed on UpdateOne operations") +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *LockMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *LockMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *LockMutation) ClearField(name string) error { + return fmt.Errorf("unknown Lock nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *LockMutation) ResetField(name string) error { + switch name { + case lock.FieldName: + m.ResetName() + return nil + case lock.FieldCreatedAt: + m.ResetCreatedAt() + return nil } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLastHeartbeat requires an ID field in the mutation") + return fmt.Errorf("unknown Lock field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *LockMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *LockMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *LockMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *LockMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *LockMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *LockMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *LockMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Lock unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *LockMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Lock edge %s", name) +} + +// MachineMutation represents an operation that mutates the Machine nodes in the graph. +type MachineMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + last_push *time.Time + last_heartbeat *time.Time + machineId *string + password *string + ipAddress *string + scenarios *string + version *string + isValidated *bool + auth_type *string + osname *string + osversion *string + featureflags *string + hubstate *map[string][]schema.ItemState + datasources *map[string]int64 + clearedFields map[string]struct{} + alerts map[int]struct{} + removedalerts map[int]struct{} + clearedalerts bool + done bool + oldValue func(context.Context) (*Machine, error) + predicates []predicate.Machine +} + +var _ ent.Mutation = (*MachineMutation)(nil) + +// machineOption allows management of the mutation configuration using functional options. +type machineOption func(*MachineMutation) + +// newMachineMutation creates new mutation for the Machine entity. +func newMachineMutation(c config, op Op, opts ...machineOption) *MachineMutation { + m := &MachineMutation{ + config: c, + op: op, + typ: TypeMachine, + clearedFields: make(map[string]struct{}), } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLastHeartbeat: %w", err) + for _, opt := range opts { + opt(m) } - return oldValue.LastHeartbeat, nil + return m } -// ClearLastHeartbeat clears the value of the "last_heartbeat" field. -func (m *MachineMutation) ClearLastHeartbeat() { - m.last_heartbeat = nil - m.clearedFields[machine.FieldLastHeartbeat] = struct{}{} +// withMachineID sets the ID field of the mutation. +func withMachineID(id int) machineOption { + return func(m *MachineMutation) { + var ( + err error + once sync.Once + value *Machine + ) + m.oldValue = func(ctx context.Context) (*Machine, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Machine.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// LastHeartbeatCleared returns if the "last_heartbeat" field was cleared in this mutation. -func (m *MachineMutation) LastHeartbeatCleared() bool { - _, ok := m.clearedFields[machine.FieldLastHeartbeat] - return ok +// withMachine sets the old Machine of the mutation. +func withMachine(node *Machine) machineOption { + return func(m *MachineMutation) { + m.oldValue = func(context.Context) (*Machine, error) { + return node, nil + } + m.id = &node.ID + } } -// ResetLastHeartbeat resets all changes to the "last_heartbeat" field. -func (m *MachineMutation) ResetLastHeartbeat() { - m.last_heartbeat = nil - delete(m.clearedFields, machine.FieldLastHeartbeat) +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MachineMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// SetMachineId sets the "machineId" field. -func (m *MachineMutation) SetMachineId(s string) { - m.machineId = &s +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MachineMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// MachineId returns the value of the "machineId" field in the mutation. -func (m *MachineMutation) MachineId() (r string, exists bool) { - v := m.machineId +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *MachineMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *MachineMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Machine.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *MachineMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *MachineMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldMachineId returns the old "machineId" field's value of the Machine entity. +// OldCreatedAt returns the old "created_at" field's value of the Machine entity. // If the Machine object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldMachineId(ctx context.Context) (v string, err error) { +func (m *MachineMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMachineId is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *MachineMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *MachineMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *MachineMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *MachineMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetLastPush sets the "last_push" field. +func (m *MachineMutation) SetLastPush(t time.Time) { + m.last_push = &t +} + +// LastPush returns the value of the "last_push" field in the mutation. +func (m *MachineMutation) LastPush() (r time.Time, exists bool) { + v := m.last_push + if v == nil { + return + } + return *v, true +} + +// OldLastPush returns the old "last_push" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldLastPush(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastPush is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastPush requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastPush: %w", err) + } + return oldValue.LastPush, nil +} + +// ClearLastPush clears the value of the "last_push" field. +func (m *MachineMutation) ClearLastPush() { + m.last_push = nil + m.clearedFields[machine.FieldLastPush] = struct{}{} +} + +// LastPushCleared returns if the "last_push" field was cleared in this mutation. +func (m *MachineMutation) LastPushCleared() bool { + _, ok := m.clearedFields[machine.FieldLastPush] + return ok +} + +// ResetLastPush resets all changes to the "last_push" field. +func (m *MachineMutation) ResetLastPush() { + m.last_push = nil + delete(m.clearedFields, machine.FieldLastPush) +} + +// SetLastHeartbeat sets the "last_heartbeat" field. +func (m *MachineMutation) SetLastHeartbeat(t time.Time) { + m.last_heartbeat = &t +} + +// LastHeartbeat returns the value of the "last_heartbeat" field in the mutation. +func (m *MachineMutation) LastHeartbeat() (r time.Time, exists bool) { + v := m.last_heartbeat + if v == nil { + return + } + return *v, true +} + +// OldLastHeartbeat returns the old "last_heartbeat" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldLastHeartbeat(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastHeartbeat is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastHeartbeat requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastHeartbeat: %w", err) + } + return oldValue.LastHeartbeat, nil +} + +// ClearLastHeartbeat clears the value of the "last_heartbeat" field. +func (m *MachineMutation) ClearLastHeartbeat() { + m.last_heartbeat = nil + m.clearedFields[machine.FieldLastHeartbeat] = struct{}{} +} + +// LastHeartbeatCleared returns if the "last_heartbeat" field was cleared in this mutation. +func (m *MachineMutation) LastHeartbeatCleared() bool { + _, ok := m.clearedFields[machine.FieldLastHeartbeat] + return ok +} + +// ResetLastHeartbeat resets all changes to the "last_heartbeat" field. +func (m *MachineMutation) ResetLastHeartbeat() { + m.last_heartbeat = nil + delete(m.clearedFields, machine.FieldLastHeartbeat) +} + +// SetMachineId sets the "machineId" field. +func (m *MachineMutation) SetMachineId(s string) { + m.machineId = &s +} + +// MachineId returns the value of the "machineId" field in the mutation. +func (m *MachineMutation) MachineId() (r string, exists bool) { + v := m.machineId + if v == nil { + return + } + return *v, true +} + +// OldMachineId returns the old "machineId" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldMachineId(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMachineId is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { return v, errors.New("OldMachineId requires an ID field in the mutation") @@ -6456,395 +6941,1461 @@ func (m *MachineMutation) Password() (r string, exists bool) { if v == nil { return } - return *v, true + return *v, true +} + +// OldPassword returns the old "password" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldPassword(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassword is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassword requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassword: %w", err) + } + return oldValue.Password, nil +} + +// ResetPassword resets all changes to the "password" field. +func (m *MachineMutation) ResetPassword() { + m.password = nil +} + +// SetIpAddress sets the "ipAddress" field. +func (m *MachineMutation) SetIpAddress(s string) { + m.ipAddress = &s +} + +// IpAddress returns the value of the "ipAddress" field in the mutation. +func (m *MachineMutation) IpAddress() (r string, exists bool) { + v := m.ipAddress + if v == nil { + return + } + return *v, true +} + +// OldIpAddress returns the old "ipAddress" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldIpAddress(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIpAddress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIpAddress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIpAddress: %w", err) + } + return oldValue.IpAddress, nil +} + +// ResetIpAddress resets all changes to the "ipAddress" field. +func (m *MachineMutation) ResetIpAddress() { + m.ipAddress = nil +} + +// SetScenarios sets the "scenarios" field. +func (m *MachineMutation) SetScenarios(s string) { + m.scenarios = &s +} + +// Scenarios returns the value of the "scenarios" field in the mutation. +func (m *MachineMutation) Scenarios() (r string, exists bool) { + v := m.scenarios + if v == nil { + return + } + return *v, true +} + +// OldScenarios returns the old "scenarios" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldScenarios(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScenarios is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScenarios requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScenarios: %w", err) + } + return oldValue.Scenarios, nil +} + +// ClearScenarios clears the value of the "scenarios" field. +func (m *MachineMutation) ClearScenarios() { + m.scenarios = nil + m.clearedFields[machine.FieldScenarios] = struct{}{} +} + +// ScenariosCleared returns if the "scenarios" field was cleared in this mutation. +func (m *MachineMutation) ScenariosCleared() bool { + _, ok := m.clearedFields[machine.FieldScenarios] + return ok +} + +// ResetScenarios resets all changes to the "scenarios" field. +func (m *MachineMutation) ResetScenarios() { + m.scenarios = nil + delete(m.clearedFields, machine.FieldScenarios) +} + +// SetVersion sets the "version" field. +func (m *MachineMutation) SetVersion(s string) { + m.version = &s +} + +// Version returns the value of the "version" field in the mutation. +func (m *MachineMutation) Version() (r string, exists bool) { + v := m.version + if v == nil { + return + } + return *v, true +} + +// OldVersion returns the old "version" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldVersion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVersion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVersion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVersion: %w", err) + } + return oldValue.Version, nil +} + +// ClearVersion clears the value of the "version" field. +func (m *MachineMutation) ClearVersion() { + m.version = nil + m.clearedFields[machine.FieldVersion] = struct{}{} +} + +// VersionCleared returns if the "version" field was cleared in this mutation. +func (m *MachineMutation) VersionCleared() bool { + _, ok := m.clearedFields[machine.FieldVersion] + return ok +} + +// ResetVersion resets all changes to the "version" field. +func (m *MachineMutation) ResetVersion() { + m.version = nil + delete(m.clearedFields, machine.FieldVersion) +} + +// SetIsValidated sets the "isValidated" field. +func (m *MachineMutation) SetIsValidated(b bool) { + m.isValidated = &b +} + +// IsValidated returns the value of the "isValidated" field in the mutation. +func (m *MachineMutation) IsValidated() (r bool, exists bool) { + v := m.isValidated + if v == nil { + return + } + return *v, true +} + +// OldIsValidated returns the old "isValidated" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldIsValidated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsValidated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsValidated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsValidated: %w", err) + } + return oldValue.IsValidated, nil +} + +// ResetIsValidated resets all changes to the "isValidated" field. +func (m *MachineMutation) ResetIsValidated() { + m.isValidated = nil +} + +// SetAuthType sets the "auth_type" field. +func (m *MachineMutation) SetAuthType(s string) { + m.auth_type = &s +} + +// AuthType returns the value of the "auth_type" field in the mutation. +func (m *MachineMutation) AuthType() (r string, exists bool) { + v := m.auth_type + if v == nil { + return + } + return *v, true +} + +// OldAuthType returns the old "auth_type" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + } + return oldValue.AuthType, nil +} + +// ResetAuthType resets all changes to the "auth_type" field. +func (m *MachineMutation) ResetAuthType() { + m.auth_type = nil +} + +// SetOsname sets the "osname" field. +func (m *MachineMutation) SetOsname(s string) { + m.osname = &s +} + +// Osname returns the value of the "osname" field in the mutation. +func (m *MachineMutation) Osname() (r string, exists bool) { + v := m.osname + if v == nil { + return + } + return *v, true +} + +// OldOsname returns the old "osname" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldOsname(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsname is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsname requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsname: %w", err) + } + return oldValue.Osname, nil +} + +// ClearOsname clears the value of the "osname" field. +func (m *MachineMutation) ClearOsname() { + m.osname = nil + m.clearedFields[machine.FieldOsname] = struct{}{} +} + +// OsnameCleared returns if the "osname" field was cleared in this mutation. +func (m *MachineMutation) OsnameCleared() bool { + _, ok := m.clearedFields[machine.FieldOsname] + return ok +} + +// ResetOsname resets all changes to the "osname" field. +func (m *MachineMutation) ResetOsname() { + m.osname = nil + delete(m.clearedFields, machine.FieldOsname) +} + +// SetOsversion sets the "osversion" field. +func (m *MachineMutation) SetOsversion(s string) { + m.osversion = &s +} + +// Osversion returns the value of the "osversion" field in the mutation. +func (m *MachineMutation) Osversion() (r string, exists bool) { + v := m.osversion + if v == nil { + return + } + return *v, true +} + +// OldOsversion returns the old "osversion" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldOsversion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOsversion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOsversion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOsversion: %w", err) + } + return oldValue.Osversion, nil +} + +// ClearOsversion clears the value of the "osversion" field. +func (m *MachineMutation) ClearOsversion() { + m.osversion = nil + m.clearedFields[machine.FieldOsversion] = struct{}{} +} + +// OsversionCleared returns if the "osversion" field was cleared in this mutation. +func (m *MachineMutation) OsversionCleared() bool { + _, ok := m.clearedFields[machine.FieldOsversion] + return ok +} + +// ResetOsversion resets all changes to the "osversion" field. +func (m *MachineMutation) ResetOsversion() { + m.osversion = nil + delete(m.clearedFields, machine.FieldOsversion) +} + +// SetFeatureflags sets the "featureflags" field. +func (m *MachineMutation) SetFeatureflags(s string) { + m.featureflags = &s +} + +// Featureflags returns the value of the "featureflags" field in the mutation. +func (m *MachineMutation) Featureflags() (r string, exists bool) { + v := m.featureflags + if v == nil { + return + } + return *v, true +} + +// OldFeatureflags returns the old "featureflags" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldFeatureflags(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeatureflags is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeatureflags requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeatureflags: %w", err) + } + return oldValue.Featureflags, nil +} + +// ClearFeatureflags clears the value of the "featureflags" field. +func (m *MachineMutation) ClearFeatureflags() { + m.featureflags = nil + m.clearedFields[machine.FieldFeatureflags] = struct{}{} +} + +// FeatureflagsCleared returns if the "featureflags" field was cleared in this mutation. +func (m *MachineMutation) FeatureflagsCleared() bool { + _, ok := m.clearedFields[machine.FieldFeatureflags] + return ok +} + +// ResetFeatureflags resets all changes to the "featureflags" field. +func (m *MachineMutation) ResetFeatureflags() { + m.featureflags = nil + delete(m.clearedFields, machine.FieldFeatureflags) +} + +// SetHubstate sets the "hubstate" field. +func (m *MachineMutation) SetHubstate(ms map[string][]schema.ItemState) { + m.hubstate = &ms +} + +// Hubstate returns the value of the "hubstate" field in the mutation. +func (m *MachineMutation) Hubstate() (r map[string][]schema.ItemState, exists bool) { + v := m.hubstate + if v == nil { + return + } + return *v, true +} + +// OldHubstate returns the old "hubstate" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldHubstate(ctx context.Context) (v map[string][]schema.ItemState, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldHubstate is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldHubstate requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldHubstate: %w", err) + } + return oldValue.Hubstate, nil +} + +// ClearHubstate clears the value of the "hubstate" field. +func (m *MachineMutation) ClearHubstate() { + m.hubstate = nil + m.clearedFields[machine.FieldHubstate] = struct{}{} +} + +// HubstateCleared returns if the "hubstate" field was cleared in this mutation. +func (m *MachineMutation) HubstateCleared() bool { + _, ok := m.clearedFields[machine.FieldHubstate] + return ok +} + +// ResetHubstate resets all changes to the "hubstate" field. +func (m *MachineMutation) ResetHubstate() { + m.hubstate = nil + delete(m.clearedFields, machine.FieldHubstate) +} + +// SetDatasources sets the "datasources" field. +func (m *MachineMutation) SetDatasources(value map[string]int64) { + m.datasources = &value +} + +// Datasources returns the value of the "datasources" field in the mutation. +func (m *MachineMutation) Datasources() (r map[string]int64, exists bool) { + v := m.datasources + if v == nil { + return + } + return *v, true +} + +// OldDatasources returns the old "datasources" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldDatasources(ctx context.Context) (v map[string]int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDatasources is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDatasources requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDatasources: %w", err) + } + return oldValue.Datasources, nil +} + +// ClearDatasources clears the value of the "datasources" field. +func (m *MachineMutation) ClearDatasources() { + m.datasources = nil + m.clearedFields[machine.FieldDatasources] = struct{}{} +} + +// DatasourcesCleared returns if the "datasources" field was cleared in this mutation. +func (m *MachineMutation) DatasourcesCleared() bool { + _, ok := m.clearedFields[machine.FieldDatasources] + return ok +} + +// ResetDatasources resets all changes to the "datasources" field. +func (m *MachineMutation) ResetDatasources() { + m.datasources = nil + delete(m.clearedFields, machine.FieldDatasources) +} + +// AddAlertIDs adds the "alerts" edge to the Alert entity by ids. +func (m *MachineMutation) AddAlertIDs(ids ...int) { + if m.alerts == nil { + m.alerts = make(map[int]struct{}) + } + for i := range ids { + m.alerts[ids[i]] = struct{}{} + } +} + +// ClearAlerts clears the "alerts" edge to the Alert entity. +func (m *MachineMutation) ClearAlerts() { + m.clearedalerts = true +} + +// AlertsCleared reports if the "alerts" edge to the Alert entity was cleared. +func (m *MachineMutation) AlertsCleared() bool { + return m.clearedalerts +} + +// RemoveAlertIDs removes the "alerts" edge to the Alert entity by IDs. +func (m *MachineMutation) RemoveAlertIDs(ids ...int) { + if m.removedalerts == nil { + m.removedalerts = make(map[int]struct{}) + } + for i := range ids { + delete(m.alerts, ids[i]) + m.removedalerts[ids[i]] = struct{}{} + } +} + +// RemovedAlerts returns the removed IDs of the "alerts" edge to the Alert entity. +func (m *MachineMutation) RemovedAlertsIDs() (ids []int) { + for id := range m.removedalerts { + ids = append(ids, id) + } + return +} + +// AlertsIDs returns the "alerts" edge IDs in the mutation. +func (m *MachineMutation) AlertsIDs() (ids []int) { + for id := range m.alerts { + ids = append(ids, id) + } + return +} + +// ResetAlerts resets all changes to the "alerts" edge. +func (m *MachineMutation) ResetAlerts() { + m.alerts = nil + m.clearedalerts = false + m.removedalerts = nil +} + +// Where appends a list predicates to the MachineMutation builder. +func (m *MachineMutation) Where(ps ...predicate.Machine) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the MachineMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MachineMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Machine, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *MachineMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *MachineMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Machine). +func (m *MachineMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *MachineMutation) Fields() []string { + fields := make([]string, 0, 16) + if m.created_at != nil { + fields = append(fields, machine.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, machine.FieldUpdatedAt) + } + if m.last_push != nil { + fields = append(fields, machine.FieldLastPush) + } + if m.last_heartbeat != nil { + fields = append(fields, machine.FieldLastHeartbeat) + } + if m.machineId != nil { + fields = append(fields, machine.FieldMachineId) + } + if m.password != nil { + fields = append(fields, machine.FieldPassword) + } + if m.ipAddress != nil { + fields = append(fields, machine.FieldIpAddress) + } + if m.scenarios != nil { + fields = append(fields, machine.FieldScenarios) + } + if m.version != nil { + fields = append(fields, machine.FieldVersion) + } + if m.isValidated != nil { + fields = append(fields, machine.FieldIsValidated) + } + if m.auth_type != nil { + fields = append(fields, machine.FieldAuthType) + } + if m.osname != nil { + fields = append(fields, machine.FieldOsname) + } + if m.osversion != nil { + fields = append(fields, machine.FieldOsversion) + } + if m.featureflags != nil { + fields = append(fields, machine.FieldFeatureflags) + } + if m.hubstate != nil { + fields = append(fields, machine.FieldHubstate) + } + if m.datasources != nil { + fields = append(fields, machine.FieldDatasources) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *MachineMutation) Field(name string) (ent.Value, bool) { + switch name { + case machine.FieldCreatedAt: + return m.CreatedAt() + case machine.FieldUpdatedAt: + return m.UpdatedAt() + case machine.FieldLastPush: + return m.LastPush() + case machine.FieldLastHeartbeat: + return m.LastHeartbeat() + case machine.FieldMachineId: + return m.MachineId() + case machine.FieldPassword: + return m.Password() + case machine.FieldIpAddress: + return m.IpAddress() + case machine.FieldScenarios: + return m.Scenarios() + case machine.FieldVersion: + return m.Version() + case machine.FieldIsValidated: + return m.IsValidated() + case machine.FieldAuthType: + return m.AuthType() + case machine.FieldOsname: + return m.Osname() + case machine.FieldOsversion: + return m.Osversion() + case machine.FieldFeatureflags: + return m.Featureflags() + case machine.FieldHubstate: + return m.Hubstate() + case machine.FieldDatasources: + return m.Datasources() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case machine.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case machine.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case machine.FieldLastPush: + return m.OldLastPush(ctx) + case machine.FieldLastHeartbeat: + return m.OldLastHeartbeat(ctx) + case machine.FieldMachineId: + return m.OldMachineId(ctx) + case machine.FieldPassword: + return m.OldPassword(ctx) + case machine.FieldIpAddress: + return m.OldIpAddress(ctx) + case machine.FieldScenarios: + return m.OldScenarios(ctx) + case machine.FieldVersion: + return m.OldVersion(ctx) + case machine.FieldIsValidated: + return m.OldIsValidated(ctx) + case machine.FieldAuthType: + return m.OldAuthType(ctx) + case machine.FieldOsname: + return m.OldOsname(ctx) + case machine.FieldOsversion: + return m.OldOsversion(ctx) + case machine.FieldFeatureflags: + return m.OldFeatureflags(ctx) + case machine.FieldHubstate: + return m.OldHubstate(ctx) + case machine.FieldDatasources: + return m.OldDatasources(ctx) + } + return nil, fmt.Errorf("unknown Machine field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MachineMutation) SetField(name string, value ent.Value) error { + switch name { + case machine.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case machine.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case machine.FieldLastPush: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastPush(v) + return nil + case machine.FieldLastHeartbeat: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastHeartbeat(v) + return nil + case machine.FieldMachineId: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMachineId(v) + return nil + case machine.FieldPassword: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassword(v) + return nil + case machine.FieldIpAddress: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIpAddress(v) + return nil + case machine.FieldScenarios: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScenarios(v) + return nil + case machine.FieldVersion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVersion(v) + return nil + case machine.FieldIsValidated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsValidated(v) + return nil + case machine.FieldAuthType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthType(v) + return nil + case machine.FieldOsname: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsname(v) + return nil + case machine.FieldOsversion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOsversion(v) + return nil + case machine.FieldFeatureflags: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeatureflags(v) + return nil + case machine.FieldHubstate: + v, ok := value.(map[string][]schema.ItemState) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetHubstate(v) + return nil + case machine.FieldDatasources: + v, ok := value.(map[string]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDatasources(v) + return nil + } + return fmt.Errorf("unknown Machine field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *MachineMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *MachineMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MachineMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Machine numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *MachineMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(machine.FieldLastPush) { + fields = append(fields, machine.FieldLastPush) + } + if m.FieldCleared(machine.FieldLastHeartbeat) { + fields = append(fields, machine.FieldLastHeartbeat) + } + if m.FieldCleared(machine.FieldScenarios) { + fields = append(fields, machine.FieldScenarios) + } + if m.FieldCleared(machine.FieldVersion) { + fields = append(fields, machine.FieldVersion) + } + if m.FieldCleared(machine.FieldOsname) { + fields = append(fields, machine.FieldOsname) + } + if m.FieldCleared(machine.FieldOsversion) { + fields = append(fields, machine.FieldOsversion) + } + if m.FieldCleared(machine.FieldFeatureflags) { + fields = append(fields, machine.FieldFeatureflags) + } + if m.FieldCleared(machine.FieldHubstate) { + fields = append(fields, machine.FieldHubstate) + } + if m.FieldCleared(machine.FieldDatasources) { + fields = append(fields, machine.FieldDatasources) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *MachineMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *MachineMutation) ClearField(name string) error { + switch name { + case machine.FieldLastPush: + m.ClearLastPush() + return nil + case machine.FieldLastHeartbeat: + m.ClearLastHeartbeat() + return nil + case machine.FieldScenarios: + m.ClearScenarios() + return nil + case machine.FieldVersion: + m.ClearVersion() + return nil + case machine.FieldOsname: + m.ClearOsname() + return nil + case machine.FieldOsversion: + m.ClearOsversion() + return nil + case machine.FieldFeatureflags: + m.ClearFeatureflags() + return nil + case machine.FieldHubstate: + m.ClearHubstate() + return nil + case machine.FieldDatasources: + m.ClearDatasources() + return nil + } + return fmt.Errorf("unknown Machine nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *MachineMutation) ResetField(name string) error { + switch name { + case machine.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case machine.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case machine.FieldLastPush: + m.ResetLastPush() + return nil + case machine.FieldLastHeartbeat: + m.ResetLastHeartbeat() + return nil + case machine.FieldMachineId: + m.ResetMachineId() + return nil + case machine.FieldPassword: + m.ResetPassword() + return nil + case machine.FieldIpAddress: + m.ResetIpAddress() + return nil + case machine.FieldScenarios: + m.ResetScenarios() + return nil + case machine.FieldVersion: + m.ResetVersion() + return nil + case machine.FieldIsValidated: + m.ResetIsValidated() + return nil + case machine.FieldAuthType: + m.ResetAuthType() + return nil + case machine.FieldOsname: + m.ResetOsname() + return nil + case machine.FieldOsversion: + m.ResetOsversion() + return nil + case machine.FieldFeatureflags: + m.ResetFeatureflags() + return nil + case machine.FieldHubstate: + m.ResetHubstate() + return nil + case machine.FieldDatasources: + m.ResetDatasources() + return nil + } + return fmt.Errorf("unknown Machine field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *MachineMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.alerts != nil { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *MachineMutation) AddedIDs(name string) []ent.Value { + switch name { + case machine.EdgeAlerts: + ids := make([]ent.Value, 0, len(m.alerts)) + for id := range m.alerts { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *MachineMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedalerts != nil { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *MachineMutation) RemovedIDs(name string) []ent.Value { + switch name { + case machine.EdgeAlerts: + ids := make([]ent.Value, 0, len(m.removedalerts)) + for id := range m.removedalerts { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *MachineMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedalerts { + edges = append(edges, machine.EdgeAlerts) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *MachineMutation) EdgeCleared(name string) bool { + switch name { + case machine.EdgeAlerts: + return m.clearedalerts + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *MachineMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Machine unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *MachineMutation) ResetEdge(name string) error { + switch name { + case machine.EdgeAlerts: + m.ResetAlerts() + return nil + } + return fmt.Errorf("unknown Machine edge %s", name) +} + +// MetaMutation represents an operation that mutates the Meta nodes in the graph. +type MetaMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + key *string + value *string + clearedFields map[string]struct{} + owner *int + clearedowner bool + done bool + oldValue func(context.Context) (*Meta, error) + predicates []predicate.Meta +} + +var _ ent.Mutation = (*MetaMutation)(nil) + +// metaOption allows management of the mutation configuration using functional options. +type metaOption func(*MetaMutation) + +// newMetaMutation creates new mutation for the Meta entity. +func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { + m := &MetaMutation{ + config: c, + op: op, + typ: TypeMeta, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withMetaID sets the ID field of the mutation. +func withMetaID(id int) metaOption { + return func(m *MetaMutation) { + var ( + err error + once sync.Once + value *Meta + ) + m.oldValue = func(ctx context.Context) (*Meta, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Meta.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// OldPassword returns the old "password" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldPassword(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassword is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassword requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassword: %w", err) +// withMeta sets the old Meta of the mutation. +func withMeta(node *Meta) metaOption { + return func(m *MetaMutation) { + m.oldValue = func(context.Context) (*Meta, error) { + return node, nil + } + m.id = &node.ID } - return oldValue.Password, nil } -// ResetPassword resets all changes to the "password" field. -func (m *MachineMutation) ResetPassword() { - m.password = nil +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MetaMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// SetIpAddress sets the "ipAddress" field. -func (m *MachineMutation) SetIpAddress(s string) { - m.ipAddress = &s +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MetaMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// IpAddress returns the value of the "ipAddress" field in the mutation. -func (m *MachineMutation) IpAddress() (r string, exists bool) { - v := m.ipAddress - if v == nil { +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *MetaMutation) ID() (id int, exists bool) { + if m.id == nil { return } - return *v, true + return *m.id, true } -// OldIpAddress returns the old "ipAddress" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldIpAddress(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIpAddress is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIpAddress requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldIpAddress: %w", err) +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Meta.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } - return oldValue.IpAddress, nil -} - -// ResetIpAddress resets all changes to the "ipAddress" field. -func (m *MachineMutation) ResetIpAddress() { - m.ipAddress = nil } -// SetScenarios sets the "scenarios" field. -func (m *MachineMutation) SetScenarios(s string) { - m.scenarios = &s +// SetCreatedAt sets the "created_at" field. +func (m *MetaMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// Scenarios returns the value of the "scenarios" field in the mutation. -func (m *MachineMutation) Scenarios() (r string, exists bool) { - v := m.scenarios +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *MetaMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldScenarios returns the old "scenarios" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldScenarios(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldScenarios is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldScenarios requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldScenarios: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.Scenarios, nil -} - -// ClearScenarios clears the value of the "scenarios" field. -func (m *MachineMutation) ClearScenarios() { - m.scenarios = nil - m.clearedFields[machine.FieldScenarios] = struct{}{} -} - -// ScenariosCleared returns if the "scenarios" field was cleared in this mutation. -func (m *MachineMutation) ScenariosCleared() bool { - _, ok := m.clearedFields[machine.FieldScenarios] - return ok + return oldValue.CreatedAt, nil } -// ResetScenarios resets all changes to the "scenarios" field. -func (m *MachineMutation) ResetScenarios() { - m.scenarios = nil - delete(m.clearedFields, machine.FieldScenarios) +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *MetaMutation) ResetCreatedAt() { + m.created_at = nil } -// SetVersion sets the "version" field. -func (m *MachineMutation) SetVersion(s string) { - m.version = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *MetaMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// Version returns the value of the "version" field in the mutation. -func (m *MachineMutation) Version() (r string, exists bool) { - v := m.version +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *MetaMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldVersion returns the old "version" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldVersion(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldVersion is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldVersion requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldVersion: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.Version, nil -} - -// ClearVersion clears the value of the "version" field. -func (m *MachineMutation) ClearVersion() { - m.version = nil - m.clearedFields[machine.FieldVersion] = struct{}{} -} - -// VersionCleared returns if the "version" field was cleared in this mutation. -func (m *MachineMutation) VersionCleared() bool { - _, ok := m.clearedFields[machine.FieldVersion] - return ok + return oldValue.UpdatedAt, nil } -// ResetVersion resets all changes to the "version" field. -func (m *MachineMutation) ResetVersion() { - m.version = nil - delete(m.clearedFields, machine.FieldVersion) +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *MetaMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetIsValidated sets the "isValidated" field. -func (m *MachineMutation) SetIsValidated(b bool) { - m.isValidated = &b +// SetKey sets the "key" field. +func (m *MetaMutation) SetKey(s string) { + m.key = &s } -// IsValidated returns the value of the "isValidated" field in the mutation. -func (m *MachineMutation) IsValidated() (r bool, exists bool) { - v := m.isValidated +// Key returns the value of the "key" field in the mutation. +func (m *MetaMutation) Key() (r string, exists bool) { + v := m.key if v == nil { return } return *v, true } -// OldIsValidated returns the old "isValidated" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldKey returns the old "key" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldIsValidated(ctx context.Context) (v bool, err error) { +func (m *MetaMutation) OldKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIsValidated is only allowed on UpdateOne operations") + return v, errors.New("OldKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIsValidated requires an ID field in the mutation") + return v, errors.New("OldKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIsValidated: %w", err) + return v, fmt.Errorf("querying old value for OldKey: %w", err) } - return oldValue.IsValidated, nil + return oldValue.Key, nil } -// ResetIsValidated resets all changes to the "isValidated" field. -func (m *MachineMutation) ResetIsValidated() { - m.isValidated = nil +// ResetKey resets all changes to the "key" field. +func (m *MetaMutation) ResetKey() { + m.key = nil } -// SetStatus sets the "status" field. -func (m *MachineMutation) SetStatus(s string) { - m.status = &s +// SetValue sets the "value" field. +func (m *MetaMutation) SetValue(s string) { + m.value = &s } -// Status returns the value of the "status" field in the mutation. -func (m *MachineMutation) Status() (r string, exists bool) { - v := m.status +// Value returns the value of the "value" field in the mutation. +func (m *MetaMutation) Value() (r string, exists bool) { + v := m.value if v == nil { return } return *v, true } -// OldStatus returns the old "status" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldValue returns the old "value" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldValue(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") + return v, errors.New("OldValue is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + return v, errors.New("OldValue requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + return v, fmt.Errorf("querying old value for OldValue: %w", err) } - return oldValue.Status, nil -} - -// ClearStatus clears the value of the "status" field. -func (m *MachineMutation) ClearStatus() { - m.status = nil - m.clearedFields[machine.FieldStatus] = struct{}{} -} - -// StatusCleared returns if the "status" field was cleared in this mutation. -func (m *MachineMutation) StatusCleared() bool { - _, ok := m.clearedFields[machine.FieldStatus] - return ok + return oldValue.Value, nil } -// ResetStatus resets all changes to the "status" field. -func (m *MachineMutation) ResetStatus() { - m.status = nil - delete(m.clearedFields, machine.FieldStatus) +// ResetValue resets all changes to the "value" field. +func (m *MetaMutation) ResetValue() { + m.value = nil } -// SetAuthType sets the "auth_type" field. -func (m *MachineMutation) SetAuthType(s string) { - m.auth_type = &s +// SetAlertMetas sets the "alert_metas" field. +func (m *MetaMutation) SetAlertMetas(i int) { + m.owner = &i } -// AuthType returns the value of the "auth_type" field in the mutation. -func (m *MachineMutation) AuthType() (r string, exists bool) { - v := m.auth_type +// AlertMetas returns the value of the "alert_metas" field in the mutation. +func (m *MetaMutation) AlertMetas() (r int, exists bool) { + v := m.owner if v == nil { return } return *v, true } -// OldAuthType returns the old "auth_type" field's value of the Machine entity. -// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// OldAlertMetas returns the old "alert_metas" field's value of the Meta entity. +// If the Meta object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) { +func (m *MetaMutation) OldAlertMetas(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + return v, errors.New("OldAlertMetas is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAuthType requires an ID field in the mutation") + return v, errors.New("OldAlertMetas requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + return v, fmt.Errorf("querying old value for OldAlertMetas: %w", err) } - return oldValue.AuthType, nil + return oldValue.AlertMetas, nil } -// ResetAuthType resets all changes to the "auth_type" field. -func (m *MachineMutation) ResetAuthType() { - m.auth_type = nil +// ClearAlertMetas clears the value of the "alert_metas" field. +func (m *MetaMutation) ClearAlertMetas() { + m.owner = nil + m.clearedFields[meta.FieldAlertMetas] = struct{}{} } -// AddAlertIDs adds the "alerts" edge to the Alert entity by ids. -func (m *MachineMutation) AddAlertIDs(ids ...int) { - if m.alerts == nil { - m.alerts = make(map[int]struct{}) - } - for i := range ids { - m.alerts[ids[i]] = struct{}{} - } +// AlertMetasCleared returns if the "alert_metas" field was cleared in this mutation. +func (m *MetaMutation) AlertMetasCleared() bool { + _, ok := m.clearedFields[meta.FieldAlertMetas] + return ok } -// ClearAlerts clears the "alerts" edge to the Alert entity. -func (m *MachineMutation) ClearAlerts() { - m.clearedalerts = true +// ResetAlertMetas resets all changes to the "alert_metas" field. +func (m *MetaMutation) ResetAlertMetas() { + m.owner = nil + delete(m.clearedFields, meta.FieldAlertMetas) } -// AlertsCleared reports if the "alerts" edge to the Alert entity was cleared. -func (m *MachineMutation) AlertsCleared() bool { - return m.clearedalerts +// SetOwnerID sets the "owner" edge to the Alert entity by id. +func (m *MetaMutation) SetOwnerID(id int) { + m.owner = &id } -// RemoveAlertIDs removes the "alerts" edge to the Alert entity by IDs. -func (m *MachineMutation) RemoveAlertIDs(ids ...int) { - if m.removedalerts == nil { - m.removedalerts = make(map[int]struct{}) - } - for i := range ids { - delete(m.alerts, ids[i]) - m.removedalerts[ids[i]] = struct{}{} - } +// ClearOwner clears the "owner" edge to the Alert entity. +func (m *MetaMutation) ClearOwner() { + m.clearedowner = true + m.clearedFields[meta.FieldAlertMetas] = struct{}{} } -// RemovedAlerts returns the removed IDs of the "alerts" edge to the Alert entity. -func (m *MachineMutation) RemovedAlertsIDs() (ids []int) { - for id := range m.removedalerts { - ids = append(ids, id) +// OwnerCleared reports if the "owner" edge to the Alert entity was cleared. +func (m *MetaMutation) OwnerCleared() bool { + return m.AlertMetasCleared() || m.clearedowner +} + +// OwnerID returns the "owner" edge ID in the mutation. +func (m *MetaMutation) OwnerID() (id int, exists bool) { + if m.owner != nil { + return *m.owner, true } return } -// AlertsIDs returns the "alerts" edge IDs in the mutation. -func (m *MachineMutation) AlertsIDs() (ids []int) { - for id := range m.alerts { - ids = append(ids, id) +// OwnerIDs returns the "owner" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// OwnerID instead. It exists only for internal usage by the builders. +func (m *MetaMutation) OwnerIDs() (ids []int) { + if id := m.owner; id != nil { + ids = append(ids, *id) } return } -// ResetAlerts resets all changes to the "alerts" edge. -func (m *MachineMutation) ResetAlerts() { - m.alerts = nil - m.clearedalerts = false - m.removedalerts = nil +// ResetOwner resets all changes to the "owner" edge. +func (m *MetaMutation) ResetOwner() { + m.owner = nil + m.clearedowner = false } -// Where appends a list predicates to the MachineMutation builder. -func (m *MachineMutation) Where(ps ...predicate.Machine) { +// Where appends a list predicates to the MetaMutation builder. +func (m *MetaMutation) Where(ps ...predicate.Meta) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MetaMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetaMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Meta, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. -func (m *MachineMutation) Op() Op { +func (m *MetaMutation) Op() Op { return m.op } -// Type returns the node type of this mutation (Machine). -func (m *MachineMutation) Type() string { +// SetOp allows setting the mutation operation. +func (m *MetaMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Meta). +func (m *MetaMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *MachineMutation) Fields() []string { - fields := make([]string, 0, 12) +func (m *MetaMutation) Fields() []string { + fields := make([]string, 0, 5) if m.created_at != nil { - fields = append(fields, machine.FieldCreatedAt) + fields = append(fields, meta.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, machine.FieldUpdatedAt) - } - if m.last_push != nil { - fields = append(fields, machine.FieldLastPush) - } - if m.last_heartbeat != nil { - fields = append(fields, machine.FieldLastHeartbeat) - } - if m.machineId != nil { - fields = append(fields, machine.FieldMachineId) - } - if m.password != nil { - fields = append(fields, machine.FieldPassword) - } - if m.ipAddress != nil { - fields = append(fields, machine.FieldIpAddress) - } - if m.scenarios != nil { - fields = append(fields, machine.FieldScenarios) - } - if m.version != nil { - fields = append(fields, machine.FieldVersion) + fields = append(fields, meta.FieldUpdatedAt) } - if m.isValidated != nil { - fields = append(fields, machine.FieldIsValidated) + if m.key != nil { + fields = append(fields, meta.FieldKey) } - if m.status != nil { - fields = append(fields, machine.FieldStatus) + if m.value != nil { + fields = append(fields, meta.FieldValue) } - if m.auth_type != nil { - fields = append(fields, machine.FieldAuthType) + if m.owner != nil { + fields = append(fields, meta.FieldAlertMetas) } return fields } @@ -6852,32 +8403,18 @@ func (m *MachineMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *MachineMutation) Field(name string) (ent.Value, bool) { +func (m *MetaMutation) Field(name string) (ent.Value, bool) { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: return m.CreatedAt() - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: return m.UpdatedAt() - case machine.FieldLastPush: - return m.LastPush() - case machine.FieldLastHeartbeat: - return m.LastHeartbeat() - case machine.FieldMachineId: - return m.MachineId() - case machine.FieldPassword: - return m.Password() - case machine.FieldIpAddress: - return m.IpAddress() - case machine.FieldScenarios: - return m.Scenarios() - case machine.FieldVersion: - return m.Version() - case machine.FieldIsValidated: - return m.IsValidated() - case machine.FieldStatus: - return m.Status() - case machine.FieldAuthType: - return m.AuthType() + case meta.FieldKey: + return m.Key() + case meta.FieldValue: + return m.Value() + case meta.FieldAlertMetas: + return m.AlertMetas() } return nil, false } @@ -6885,372 +8422,244 @@ func (m *MachineMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *MetaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: return m.OldCreatedAt(ctx) - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case machine.FieldLastPush: - return m.OldLastPush(ctx) - case machine.FieldLastHeartbeat: - return m.OldLastHeartbeat(ctx) - case machine.FieldMachineId: - return m.OldMachineId(ctx) - case machine.FieldPassword: - return m.OldPassword(ctx) - case machine.FieldIpAddress: - return m.OldIpAddress(ctx) - case machine.FieldScenarios: - return m.OldScenarios(ctx) - case machine.FieldVersion: - return m.OldVersion(ctx) - case machine.FieldIsValidated: - return m.OldIsValidated(ctx) - case machine.FieldStatus: - return m.OldStatus(ctx) - case machine.FieldAuthType: - return m.OldAuthType(ctx) + case meta.FieldKey: + return m.OldKey(ctx) + case meta.FieldValue: + return m.OldValue(ctx) + case meta.FieldAlertMetas: + return m.OldAlertMetas(ctx) } - return nil, fmt.Errorf("unknown Machine field %s", name) + return nil, fmt.Errorf("unknown Meta field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *MachineMutation) SetField(name string, value ent.Value) error { +func (m *MetaMutation) SetField(name string, value ent.Value) error { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case machine.FieldLastPush: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetLastPush(v) - return nil - case machine.FieldLastHeartbeat: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetLastHeartbeat(v) - return nil - case machine.FieldMachineId: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMachineId(v) - return nil - case machine.FieldPassword: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPassword(v) - return nil - case machine.FieldIpAddress: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIpAddress(v) - return nil - case machine.FieldScenarios: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetScenarios(v) - return nil - case machine.FieldVersion: + case meta.FieldKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetVersion(v) - return nil - case machine.FieldIsValidated: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIsValidated(v) + m.SetKey(v) return nil - case machine.FieldStatus: + case meta.FieldValue: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetValue(v) return nil - case machine.FieldAuthType: - v, ok := value.(string) + case meta.FieldAlertMetas: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAuthType(v) + m.SetAlertMetas(v) return nil } - return fmt.Errorf("unknown Machine field %s", name) + return fmt.Errorf("unknown Meta field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *MachineMutation) AddedFields() []string { - return nil +func (m *MetaMutation) AddedFields() []string { + var fields []string + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *MachineMutation) AddedField(name string) (ent.Value, bool) { +func (m *MetaMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *MachineMutation) AddField(name string, value ent.Value) error { +func (m *MetaMutation) AddField(name string, value ent.Value) error { switch name { } - return fmt.Errorf("unknown Machine numeric field %s", name) + return fmt.Errorf("unknown Meta numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *MachineMutation) ClearedFields() []string { +func (m *MetaMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(machine.FieldCreatedAt) { - fields = append(fields, machine.FieldCreatedAt) - } - if m.FieldCleared(machine.FieldUpdatedAt) { - fields = append(fields, machine.FieldUpdatedAt) - } - if m.FieldCleared(machine.FieldLastPush) { - fields = append(fields, machine.FieldLastPush) - } - if m.FieldCleared(machine.FieldLastHeartbeat) { - fields = append(fields, machine.FieldLastHeartbeat) - } - if m.FieldCleared(machine.FieldScenarios) { - fields = append(fields, machine.FieldScenarios) - } - if m.FieldCleared(machine.FieldVersion) { - fields = append(fields, machine.FieldVersion) - } - if m.FieldCleared(machine.FieldStatus) { - fields = append(fields, machine.FieldStatus) + if m.FieldCleared(meta.FieldAlertMetas) { + fields = append(fields, meta.FieldAlertMetas) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *MachineMutation) FieldCleared(name string) bool { +func (m *MetaMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *MachineMutation) ClearField(name string) error { +func (m *MetaMutation) ClearField(name string) error { switch name { - case machine.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case machine.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - case machine.FieldLastPush: - m.ClearLastPush() - return nil - case machine.FieldLastHeartbeat: - m.ClearLastHeartbeat() - return nil - case machine.FieldScenarios: - m.ClearScenarios() - return nil - case machine.FieldVersion: - m.ClearVersion() - return nil - case machine.FieldStatus: - m.ClearStatus() + case meta.FieldAlertMetas: + m.ClearAlertMetas() return nil } - return fmt.Errorf("unknown Machine nullable field %s", name) + return fmt.Errorf("unknown Meta nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *MachineMutation) ResetField(name string) error { +func (m *MetaMutation) ResetField(name string) error { switch name { - case machine.FieldCreatedAt: + case meta.FieldCreatedAt: m.ResetCreatedAt() return nil - case machine.FieldUpdatedAt: + case meta.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case machine.FieldLastPush: - m.ResetLastPush() - return nil - case machine.FieldLastHeartbeat: - m.ResetLastHeartbeat() - return nil - case machine.FieldMachineId: - m.ResetMachineId() - return nil - case machine.FieldPassword: - m.ResetPassword() - return nil - case machine.FieldIpAddress: - m.ResetIpAddress() - return nil - case machine.FieldScenarios: - m.ResetScenarios() - return nil - case machine.FieldVersion: - m.ResetVersion() - return nil - case machine.FieldIsValidated: - m.ResetIsValidated() + case meta.FieldKey: + m.ResetKey() return nil - case machine.FieldStatus: - m.ResetStatus() + case meta.FieldValue: + m.ResetValue() return nil - case machine.FieldAuthType: - m.ResetAuthType() + case meta.FieldAlertMetas: + m.ResetAlertMetas() return nil } - return fmt.Errorf("unknown Machine field %s", name) + return fmt.Errorf("unknown Meta field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *MachineMutation) AddedEdges() []string { +func (m *MetaMutation) AddedEdges() []string { edges := make([]string, 0, 1) - if m.alerts != nil { - edges = append(edges, machine.EdgeAlerts) + if m.owner != nil { + edges = append(edges, meta.EdgeOwner) } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *MachineMutation) AddedIDs(name string) []ent.Value { +func (m *MetaMutation) AddedIDs(name string) []ent.Value { switch name { - case machine.EdgeAlerts: - ids := make([]ent.Value, 0, len(m.alerts)) - for id := range m.alerts { - ids = append(ids, id) + case meta.EdgeOwner: + if id := m.owner; id != nil { + return []ent.Value{*id} } - return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *MachineMutation) RemovedEdges() []string { +func (m *MetaMutation) RemovedEdges() []string { edges := make([]string, 0, 1) - if m.removedalerts != nil { - edges = append(edges, machine.EdgeAlerts) - } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *MachineMutation) RemovedIDs(name string) []ent.Value { - switch name { - case machine.EdgeAlerts: - ids := make([]ent.Value, 0, len(m.removedalerts)) - for id := range m.removedalerts { - ids = append(ids, id) - } - return ids - } +func (m *MetaMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *MachineMutation) ClearedEdges() []string { +func (m *MetaMutation) ClearedEdges() []string { edges := make([]string, 0, 1) - if m.clearedalerts { - edges = append(edges, machine.EdgeAlerts) + if m.clearedowner { + edges = append(edges, meta.EdgeOwner) } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *MachineMutation) EdgeCleared(name string) bool { +func (m *MetaMutation) EdgeCleared(name string) bool { switch name { - case machine.EdgeAlerts: - return m.clearedalerts + case meta.EdgeOwner: + return m.clearedowner } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *MachineMutation) ClearEdge(name string) error { +func (m *MetaMutation) ClearEdge(name string) error { switch name { + case meta.EdgeOwner: + m.ClearOwner() + return nil } - return fmt.Errorf("unknown Machine unique edge %s", name) + return fmt.Errorf("unknown Meta unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *MachineMutation) ResetEdge(name string) error { +func (m *MetaMutation) ResetEdge(name string) error { switch name { - case machine.EdgeAlerts: - m.ResetAlerts() + case meta.EdgeOwner: + m.ResetOwner() return nil } - return fmt.Errorf("unknown Machine edge %s", name) + return fmt.Errorf("unknown Meta edge %s", name) } -// MetaMutation represents an operation that mutates the Meta nodes in the graph. -type MetaMutation struct { +// MetricMutation represents an operation that mutates the Metric nodes in the graph. +type MetricMutation struct { config - op Op - typ string - id *int - created_at *time.Time - updated_at *time.Time - key *string - value *string - clearedFields map[string]struct{} - owner *int - clearedowner bool - done bool - oldValue func(context.Context) (*Meta, error) - predicates []predicate.Meta + op Op + typ string + id *int + generated_type *metric.GeneratedType + generated_by *string + received_at *time.Time + pushed_at *time.Time + payload *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Metric, error) + predicates []predicate.Metric } -var _ ent.Mutation = (*MetaMutation)(nil) +var _ ent.Mutation = (*MetricMutation)(nil) -// metaOption allows management of the mutation configuration using functional options. -type metaOption func(*MetaMutation) +// metricOption allows management of the mutation configuration using functional options. +type metricOption func(*MetricMutation) -// newMetaMutation creates new mutation for the Meta entity. -func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { - m := &MetaMutation{ +// newMetricMutation creates new mutation for the Metric entity. +func newMetricMutation(c config, op Op, opts ...metricOption) *MetricMutation { + m := &MetricMutation{ config: c, op: op, - typ: TypeMeta, + typ: TypeMetric, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -7259,20 +8668,20 @@ func newMetaMutation(c config, op Op, opts ...metaOption) *MetaMutation { return m } -// withMetaID sets the ID field of the mutation. -func withMetaID(id int) metaOption { - return func(m *MetaMutation) { +// withMetricID sets the ID field of the mutation. +func withMetricID(id int) metricOption { + return func(m *MetricMutation) { var ( err error once sync.Once - value *Meta + value *Metric ) - m.oldValue = func(ctx context.Context) (*Meta, error) { + m.oldValue = func(ctx context.Context) (*Metric, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Meta.Get(ctx, id) + value, err = m.Client().Metric.Get(ctx, id) } }) return value, err @@ -7281,10 +8690,10 @@ func withMetaID(id int) metaOption { } } -// withMeta sets the old Meta of the mutation. -func withMeta(node *Meta) metaOption { - return func(m *MetaMutation) { - m.oldValue = func(context.Context) (*Meta, error) { +// withMetric sets the old Metric of the mutation. +func withMetric(node *Metric) metricOption { + return func(m *MetricMutation) { + m.oldValue = func(context.Context) (*Metric, error) { return node, nil } m.id = &node.ID @@ -7293,7 +8702,7 @@ func withMeta(node *Meta) metaOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m MetaMutation) Client() *Client { +func (m MetricMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -7301,7 +8710,7 @@ func (m MetaMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m MetaMutation) Tx() (*Tx, error) { +func (m MetricMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -7312,7 +8721,7 @@ func (m MetaMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *MetaMutation) ID() (id int, exists bool) { +func (m *MetricMutation) ID() (id int, exists bool) { if m.id == nil { return } @@ -7323,7 +8732,7 @@ func (m *MetaMutation) ID() (id int, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) { +func (m *MetricMutation) IDs(ctx context.Context) ([]int, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -7332,304 +8741,254 @@ func (m *MetaMutation) IDs(ctx context.Context) ([]int, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Meta.Query().Where(m.predicates...).IDs(ctx) + return m.Client().Metric.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetCreatedAt sets the "created_at" field. -func (m *MetaMutation) SetCreatedAt(t time.Time) { - m.created_at = &t -} - -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *MetaMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at - if v == nil { - return - } - return *v, true -} - -// OldCreatedAt returns the old "created_at" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldCreatedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) - } - return oldValue.CreatedAt, nil -} - -// ClearCreatedAt clears the value of the "created_at" field. -func (m *MetaMutation) ClearCreatedAt() { - m.created_at = nil - m.clearedFields[meta.FieldCreatedAt] = struct{}{} -} - -// CreatedAtCleared returns if the "created_at" field was cleared in this mutation. -func (m *MetaMutation) CreatedAtCleared() bool { - _, ok := m.clearedFields[meta.FieldCreatedAt] - return ok -} - -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *MetaMutation) ResetCreatedAt() { - m.created_at = nil - delete(m.clearedFields, meta.FieldCreatedAt) -} - -// SetUpdatedAt sets the "updated_at" field. -func (m *MetaMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t -} - -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *MetaMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// SetGeneratedType sets the "generated_type" field. +func (m *MetricMutation) SetGeneratedType(mt metric.GeneratedType) { + m.generated_type = &mt +} + +// GeneratedType returns the value of the "generated_type" field in the mutation. +func (m *MetricMutation) GeneratedType() (r metric.GeneratedType, exists bool) { + v := m.generated_type if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldGeneratedType returns the old "generated_type" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldUpdatedAt(ctx context.Context) (v *time.Time, err error) { +func (m *MetricMutation) OldGeneratedType(ctx context.Context) (v metric.GeneratedType, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldGeneratedType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldGeneratedType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldGeneratedType: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.GeneratedType, nil } -// ClearUpdatedAt clears the value of the "updated_at" field. -func (m *MetaMutation) ClearUpdatedAt() { - m.updated_at = nil - m.clearedFields[meta.FieldUpdatedAt] = struct{}{} -} - -// UpdatedAtCleared returns if the "updated_at" field was cleared in this mutation. -func (m *MetaMutation) UpdatedAtCleared() bool { - _, ok := m.clearedFields[meta.FieldUpdatedAt] - return ok -} - -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *MetaMutation) ResetUpdatedAt() { - m.updated_at = nil - delete(m.clearedFields, meta.FieldUpdatedAt) +// ResetGeneratedType resets all changes to the "generated_type" field. +func (m *MetricMutation) ResetGeneratedType() { + m.generated_type = nil } -// SetKey sets the "key" field. -func (m *MetaMutation) SetKey(s string) { - m.key = &s +// SetGeneratedBy sets the "generated_by" field. +func (m *MetricMutation) SetGeneratedBy(s string) { + m.generated_by = &s } -// Key returns the value of the "key" field in the mutation. -func (m *MetaMutation) Key() (r string, exists bool) { - v := m.key +// GeneratedBy returns the value of the "generated_by" field in the mutation. +func (m *MetricMutation) GeneratedBy() (r string, exists bool) { + v := m.generated_by if v == nil { return } return *v, true } -// OldKey returns the old "key" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldGeneratedBy returns the old "generated_by" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldKey(ctx context.Context) (v string, err error) { +func (m *MetricMutation) OldGeneratedBy(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldKey is only allowed on UpdateOne operations") + return v, errors.New("OldGeneratedBy is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldKey requires an ID field in the mutation") + return v, errors.New("OldGeneratedBy requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldKey: %w", err) + return v, fmt.Errorf("querying old value for OldGeneratedBy: %w", err) } - return oldValue.Key, nil + return oldValue.GeneratedBy, nil } -// ResetKey resets all changes to the "key" field. -func (m *MetaMutation) ResetKey() { - m.key = nil +// ResetGeneratedBy resets all changes to the "generated_by" field. +func (m *MetricMutation) ResetGeneratedBy() { + m.generated_by = nil } -// SetValue sets the "value" field. -func (m *MetaMutation) SetValue(s string) { - m.value = &s +// SetReceivedAt sets the "received_at" field. +func (m *MetricMutation) SetReceivedAt(t time.Time) { + m.received_at = &t } -// Value returns the value of the "value" field in the mutation. -func (m *MetaMutation) Value() (r string, exists bool) { - v := m.value +// ReceivedAt returns the value of the "received_at" field in the mutation. +func (m *MetricMutation) ReceivedAt() (r time.Time, exists bool) { + v := m.received_at if v == nil { return } return *v, true } -// OldValue returns the old "value" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldReceivedAt returns the old "received_at" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldValue(ctx context.Context) (v string, err error) { +func (m *MetricMutation) OldReceivedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldValue is only allowed on UpdateOne operations") + return v, errors.New("OldReceivedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldValue requires an ID field in the mutation") + return v, errors.New("OldReceivedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldValue: %w", err) + return v, fmt.Errorf("querying old value for OldReceivedAt: %w", err) } - return oldValue.Value, nil + return oldValue.ReceivedAt, nil } -// ResetValue resets all changes to the "value" field. -func (m *MetaMutation) ResetValue() { - m.value = nil +// ResetReceivedAt resets all changes to the "received_at" field. +func (m *MetricMutation) ResetReceivedAt() { + m.received_at = nil } -// SetAlertMetas sets the "alert_metas" field. -func (m *MetaMutation) SetAlertMetas(i int) { - m.owner = &i +// SetPushedAt sets the "pushed_at" field. +func (m *MetricMutation) SetPushedAt(t time.Time) { + m.pushed_at = &t } -// AlertMetas returns the value of the "alert_metas" field in the mutation. -func (m *MetaMutation) AlertMetas() (r int, exists bool) { - v := m.owner +// PushedAt returns the value of the "pushed_at" field in the mutation. +func (m *MetricMutation) PushedAt() (r time.Time, exists bool) { + v := m.pushed_at if v == nil { return } return *v, true } -// OldAlertMetas returns the old "alert_metas" field's value of the Meta entity. -// If the Meta object wasn't provided to the builder, the object is fetched from the database. +// OldPushedAt returns the old "pushed_at" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *MetaMutation) OldAlertMetas(ctx context.Context) (v int, err error) { +func (m *MetricMutation) OldPushedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAlertMetas is only allowed on UpdateOne operations") + return v, errors.New("OldPushedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAlertMetas requires an ID field in the mutation") + return v, errors.New("OldPushedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAlertMetas: %w", err) + return v, fmt.Errorf("querying old value for OldPushedAt: %w", err) } - return oldValue.AlertMetas, nil + return oldValue.PushedAt, nil } -// ClearAlertMetas clears the value of the "alert_metas" field. -func (m *MetaMutation) ClearAlertMetas() { - m.owner = nil - m.clearedFields[meta.FieldAlertMetas] = struct{}{} +// ClearPushedAt clears the value of the "pushed_at" field. +func (m *MetricMutation) ClearPushedAt() { + m.pushed_at = nil + m.clearedFields[metric.FieldPushedAt] = struct{}{} } -// AlertMetasCleared returns if the "alert_metas" field was cleared in this mutation. -func (m *MetaMutation) AlertMetasCleared() bool { - _, ok := m.clearedFields[meta.FieldAlertMetas] +// PushedAtCleared returns if the "pushed_at" field was cleared in this mutation. +func (m *MetricMutation) PushedAtCleared() bool { + _, ok := m.clearedFields[metric.FieldPushedAt] return ok } -// ResetAlertMetas resets all changes to the "alert_metas" field. -func (m *MetaMutation) ResetAlertMetas() { - m.owner = nil - delete(m.clearedFields, meta.FieldAlertMetas) -} - -// SetOwnerID sets the "owner" edge to the Alert entity by id. -func (m *MetaMutation) SetOwnerID(id int) { - m.owner = &id -} - -// ClearOwner clears the "owner" edge to the Alert entity. -func (m *MetaMutation) ClearOwner() { - m.clearedowner = true +// ResetPushedAt resets all changes to the "pushed_at" field. +func (m *MetricMutation) ResetPushedAt() { + m.pushed_at = nil + delete(m.clearedFields, metric.FieldPushedAt) } -// OwnerCleared reports if the "owner" edge to the Alert entity was cleared. -func (m *MetaMutation) OwnerCleared() bool { - return m.AlertMetasCleared() || m.clearedowner +// SetPayload sets the "payload" field. +func (m *MetricMutation) SetPayload(s string) { + m.payload = &s } -// OwnerID returns the "owner" edge ID in the mutation. -func (m *MetaMutation) OwnerID() (id int, exists bool) { - if m.owner != nil { - return *m.owner, true +// Payload returns the value of the "payload" field in the mutation. +func (m *MetricMutation) Payload() (r string, exists bool) { + v := m.payload + if v == nil { + return } - return + return *v, true } -// OwnerIDs returns the "owner" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// OwnerID instead. It exists only for internal usage by the builders. -func (m *MetaMutation) OwnerIDs() (ids []int) { - if id := m.owner; id != nil { - ids = append(ids, *id) +// OldPayload returns the old "payload" field's value of the Metric entity. +// If the Metric object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetricMutation) OldPayload(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayload is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayload requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayload: %w", err) + } + return oldValue.Payload, nil } -// ResetOwner resets all changes to the "owner" edge. -func (m *MetaMutation) ResetOwner() { - m.owner = nil - m.clearedowner = false +// ResetPayload resets all changes to the "payload" field. +func (m *MetricMutation) ResetPayload() { + m.payload = nil } -// Where appends a list predicates to the MetaMutation builder. -func (m *MetaMutation) Where(ps ...predicate.Meta) { +// Where appends a list predicates to the MetricMutation builder. +func (m *MetricMutation) Where(ps ...predicate.Metric) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MetricMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetricMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Metric, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. -func (m *MetaMutation) Op() Op { +func (m *MetricMutation) Op() Op { return m.op } -// Type returns the node type of this mutation (Meta). -func (m *MetaMutation) Type() string { +// SetOp allows setting the mutation operation. +func (m *MetricMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Metric). +func (m *MetricMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *MetaMutation) Fields() []string { +func (m *MetricMutation) Fields() []string { fields := make([]string, 0, 5) - if m.created_at != nil { - fields = append(fields, meta.FieldCreatedAt) + if m.generated_type != nil { + fields = append(fields, metric.FieldGeneratedType) } - if m.updated_at != nil { - fields = append(fields, meta.FieldUpdatedAt) + if m.generated_by != nil { + fields = append(fields, metric.FieldGeneratedBy) } - if m.key != nil { - fields = append(fields, meta.FieldKey) + if m.received_at != nil { + fields = append(fields, metric.FieldReceivedAt) } - if m.value != nil { - fields = append(fields, meta.FieldValue) + if m.pushed_at != nil { + fields = append(fields, metric.FieldPushedAt) } - if m.owner != nil { - fields = append(fields, meta.FieldAlertMetas) + if m.payload != nil { + fields = append(fields, metric.FieldPayload) } return fields } @@ -7637,18 +8996,18 @@ func (m *MetaMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *MetaMutation) Field(name string) (ent.Value, bool) { +func (m *MetricMutation) Field(name string) (ent.Value, bool) { switch name { - case meta.FieldCreatedAt: - return m.CreatedAt() - case meta.FieldUpdatedAt: - return m.UpdatedAt() - case meta.FieldKey: - return m.Key() - case meta.FieldValue: - return m.Value() - case meta.FieldAlertMetas: - return m.AlertMetas() + case metric.FieldGeneratedType: + return m.GeneratedType() + case metric.FieldGeneratedBy: + return m.GeneratedBy() + case metric.FieldReceivedAt: + return m.ReceivedAt() + case metric.FieldPushedAt: + return m.PushedAt() + case metric.FieldPayload: + return m.Payload() } return nil, false } @@ -7656,224 +9015,183 @@ func (m *MetaMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *MetaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *MetricMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case meta.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case meta.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case meta.FieldKey: - return m.OldKey(ctx) - case meta.FieldValue: - return m.OldValue(ctx) - case meta.FieldAlertMetas: - return m.OldAlertMetas(ctx) + case metric.FieldGeneratedType: + return m.OldGeneratedType(ctx) + case metric.FieldGeneratedBy: + return m.OldGeneratedBy(ctx) + case metric.FieldReceivedAt: + return m.OldReceivedAt(ctx) + case metric.FieldPushedAt: + return m.OldPushedAt(ctx) + case metric.FieldPayload: + return m.OldPayload(ctx) } - return nil, fmt.Errorf("unknown Meta field %s", name) + return nil, fmt.Errorf("unknown Metric field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *MetaMutation) SetField(name string, value ent.Value) error { +func (m *MetricMutation) SetField(name string, value ent.Value) error { switch name { - case meta.FieldCreatedAt: - v, ok := value.(time.Time) + case metric.FieldGeneratedType: + v, ok := value.(metric.GeneratedType) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetGeneratedType(v) return nil - case meta.FieldUpdatedAt: - v, ok := value.(time.Time) + case metric.FieldGeneratedBy: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetGeneratedBy(v) return nil - case meta.FieldKey: - v, ok := value.(string) + case metric.FieldReceivedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetKey(v) + m.SetReceivedAt(v) return nil - case meta.FieldValue: - v, ok := value.(string) + case metric.FieldPushedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetValue(v) + m.SetPushedAt(v) return nil - case meta.FieldAlertMetas: - v, ok := value.(int) + case metric.FieldPayload: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAlertMetas(v) + m.SetPayload(v) return nil } - return fmt.Errorf("unknown Meta field %s", name) + return fmt.Errorf("unknown Metric field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *MetaMutation) AddedFields() []string { - var fields []string - return fields +func (m *MetricMutation) AddedFields() []string { + return nil } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *MetaMutation) AddedField(name string) (ent.Value, bool) { - switch name { - } +func (m *MetricMutation) AddedField(name string) (ent.Value, bool) { return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *MetaMutation) AddField(name string, value ent.Value) error { +func (m *MetricMutation) AddField(name string, value ent.Value) error { switch name { } - return fmt.Errorf("unknown Meta numeric field %s", name) + return fmt.Errorf("unknown Metric numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *MetaMutation) ClearedFields() []string { +func (m *MetricMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(meta.FieldCreatedAt) { - fields = append(fields, meta.FieldCreatedAt) - } - if m.FieldCleared(meta.FieldUpdatedAt) { - fields = append(fields, meta.FieldUpdatedAt) - } - if m.FieldCleared(meta.FieldAlertMetas) { - fields = append(fields, meta.FieldAlertMetas) + if m.FieldCleared(metric.FieldPushedAt) { + fields = append(fields, metric.FieldPushedAt) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *MetaMutation) FieldCleared(name string) bool { +func (m *MetricMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *MetaMutation) ClearField(name string) error { +func (m *MetricMutation) ClearField(name string) error { switch name { - case meta.FieldCreatedAt: - m.ClearCreatedAt() - return nil - case meta.FieldUpdatedAt: - m.ClearUpdatedAt() - return nil - case meta.FieldAlertMetas: - m.ClearAlertMetas() + case metric.FieldPushedAt: + m.ClearPushedAt() return nil } - return fmt.Errorf("unknown Meta nullable field %s", name) + return fmt.Errorf("unknown Metric nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *MetaMutation) ResetField(name string) error { +func (m *MetricMutation) ResetField(name string) error { switch name { - case meta.FieldCreatedAt: - m.ResetCreatedAt() + case metric.FieldGeneratedType: + m.ResetGeneratedType() return nil - case meta.FieldUpdatedAt: - m.ResetUpdatedAt() + case metric.FieldGeneratedBy: + m.ResetGeneratedBy() return nil - case meta.FieldKey: - m.ResetKey() + case metric.FieldReceivedAt: + m.ResetReceivedAt() return nil - case meta.FieldValue: - m.ResetValue() + case metric.FieldPushedAt: + m.ResetPushedAt() return nil - case meta.FieldAlertMetas: - m.ResetAlertMetas() + case metric.FieldPayload: + m.ResetPayload() return nil } - return fmt.Errorf("unknown Meta field %s", name) + return fmt.Errorf("unknown Metric field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *MetaMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.owner != nil { - edges = append(edges, meta.EdgeOwner) - } +func (m *MetricMutation) AddedEdges() []string { + edges := make([]string, 0, 0) return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *MetaMutation) AddedIDs(name string) []ent.Value { - switch name { - case meta.EdgeOwner: - if id := m.owner; id != nil { - return []ent.Value{*id} - } - } +func (m *MetricMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *MetaMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) +func (m *MetricMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *MetaMutation) RemovedIDs(name string) []ent.Value { +func (m *MetricMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *MetaMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.clearedowner { - edges = append(edges, meta.EdgeOwner) - } +func (m *MetricMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *MetaMutation) EdgeCleared(name string) bool { - switch name { - case meta.EdgeOwner: - return m.clearedowner - } +func (m *MetricMutation) EdgeCleared(name string) bool { return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *MetaMutation) ClearEdge(name string) error { - switch name { - case meta.EdgeOwner: - m.ClearOwner() - return nil - } - return fmt.Errorf("unknown Meta unique edge %s", name) +func (m *MetricMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Metric unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *MetaMutation) ResetEdge(name string) error { - switch name { - case meta.EdgeOwner: - m.ResetOwner() - return nil - } - return fmt.Errorf("unknown Meta edge %s", name) +func (m *MetricMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Metric edge %s", name) } diff --git a/pkg/database/ent/predicate/predicate.go b/pkg/database/ent/predicate/predicate.go index e95abcec343..8ad03e2fc48 100644 --- a/pkg/database/ent/predicate/predicate.go +++ b/pkg/database/ent/predicate/predicate.go @@ -21,8 +21,14 @@ type Decision func(*sql.Selector) // Event is the predicate function for event builders. type Event func(*sql.Selector) +// Lock is the predicate function for lock builders. +type Lock func(*sql.Selector) + // Machine is the predicate function for machine builders. type Machine func(*sql.Selector) // Meta is the predicate function for meta builders. type Meta func(*sql.Selector) + +// Metric is the predicate function for metric builders. +type Metric func(*sql.Selector) diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index bceea37b3a7..15413490633 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" @@ -25,8 +26,6 @@ func init() { alertDescCreatedAt := alertFields[0].Descriptor() // alert.DefaultCreatedAt holds the default value on creation for the created_at field. alert.DefaultCreatedAt = alertDescCreatedAt.Default.(func() time.Time) - // alert.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - alert.UpdateDefaultCreatedAt = alertDescCreatedAt.UpdateDefault.(func() time.Time) // alertDescUpdatedAt is the schema descriptor for updated_at field. alertDescUpdatedAt := alertFields[1].Descriptor() // alert.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -63,8 +62,6 @@ func init() { bouncerDescCreatedAt := bouncerFields[0].Descriptor() // bouncer.DefaultCreatedAt holds the default value on creation for the created_at field. bouncer.DefaultCreatedAt = bouncerDescCreatedAt.Default.(func() time.Time) - // bouncer.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - bouncer.UpdateDefaultCreatedAt = bouncerDescCreatedAt.UpdateDefault.(func() time.Time) // bouncerDescUpdatedAt is the schema descriptor for updated_at field. bouncerDescUpdatedAt := bouncerFields[1].Descriptor() // bouncer.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -75,16 +72,8 @@ func init() { bouncerDescIPAddress := bouncerFields[5].Descriptor() // bouncer.DefaultIPAddress holds the default value on creation for the ip_address field. bouncer.DefaultIPAddress = bouncerDescIPAddress.Default.(string) - // bouncerDescUntil is the schema descriptor for until field. - bouncerDescUntil := bouncerFields[8].Descriptor() - // bouncer.DefaultUntil holds the default value on creation for the until field. - bouncer.DefaultUntil = bouncerDescUntil.Default.(func() time.Time) - // bouncerDescLastPull is the schema descriptor for last_pull field. - bouncerDescLastPull := bouncerFields[9].Descriptor() - // bouncer.DefaultLastPull holds the default value on creation for the last_pull field. - bouncer.DefaultLastPull = bouncerDescLastPull.Default.(func() time.Time) // bouncerDescAuthType is the schema descriptor for auth_type field. - bouncerDescAuthType := bouncerFields[10].Descriptor() + bouncerDescAuthType := bouncerFields[9].Descriptor() // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) configitemFields := schema.ConfigItem{}.Fields() @@ -93,8 +82,6 @@ func init() { configitemDescCreatedAt := configitemFields[0].Descriptor() // configitem.DefaultCreatedAt holds the default value on creation for the created_at field. configitem.DefaultCreatedAt = configitemDescCreatedAt.Default.(func() time.Time) - // configitem.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - configitem.UpdateDefaultCreatedAt = configitemDescCreatedAt.UpdateDefault.(func() time.Time) // configitemDescUpdatedAt is the schema descriptor for updated_at field. configitemDescUpdatedAt := configitemFields[1].Descriptor() // configitem.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -107,8 +94,6 @@ func init() { decisionDescCreatedAt := decisionFields[0].Descriptor() // decision.DefaultCreatedAt holds the default value on creation for the created_at field. decision.DefaultCreatedAt = decisionDescCreatedAt.Default.(func() time.Time) - // decision.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - decision.UpdateDefaultCreatedAt = decisionDescCreatedAt.UpdateDefault.(func() time.Time) // decisionDescUpdatedAt is the schema descriptor for updated_at field. decisionDescUpdatedAt := decisionFields[1].Descriptor() // decision.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -125,8 +110,6 @@ func init() { eventDescCreatedAt := eventFields[0].Descriptor() // event.DefaultCreatedAt holds the default value on creation for the created_at field. event.DefaultCreatedAt = eventDescCreatedAt.Default.(func() time.Time) - // event.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - event.UpdateDefaultCreatedAt = eventDescCreatedAt.UpdateDefault.(func() time.Time) // eventDescUpdatedAt is the schema descriptor for updated_at field. eventDescUpdatedAt := eventFields[1].Descriptor() // event.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -137,14 +120,18 @@ func init() { eventDescSerialized := eventFields[3].Descriptor() // event.SerializedValidator is a validator for the "serialized" field. It is called by the builders before save. event.SerializedValidator = eventDescSerialized.Validators[0].(func(string) error) + lockFields := schema.Lock{}.Fields() + _ = lockFields + // lockDescCreatedAt is the schema descriptor for created_at field. + lockDescCreatedAt := lockFields[1].Descriptor() + // lock.DefaultCreatedAt holds the default value on creation for the created_at field. + lock.DefaultCreatedAt = lockDescCreatedAt.Default.(func() time.Time) machineFields := schema.Machine{}.Fields() _ = machineFields // machineDescCreatedAt is the schema descriptor for created_at field. machineDescCreatedAt := machineFields[0].Descriptor() // machine.DefaultCreatedAt holds the default value on creation for the created_at field. machine.DefaultCreatedAt = machineDescCreatedAt.Default.(func() time.Time) - // machine.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - machine.UpdateDefaultCreatedAt = machineDescCreatedAt.UpdateDefault.(func() time.Time) // machineDescUpdatedAt is the schema descriptor for updated_at field. machineDescUpdatedAt := machineFields[1].Descriptor() // machine.DefaultUpdatedAt holds the default value on creation for the updated_at field. @@ -155,14 +142,6 @@ func init() { machineDescLastPush := machineFields[2].Descriptor() // machine.DefaultLastPush holds the default value on creation for the last_push field. machine.DefaultLastPush = machineDescLastPush.Default.(func() time.Time) - // machine.UpdateDefaultLastPush holds the default value on update for the last_push field. - machine.UpdateDefaultLastPush = machineDescLastPush.UpdateDefault.(func() time.Time) - // machineDescLastHeartbeat is the schema descriptor for last_heartbeat field. - machineDescLastHeartbeat := machineFields[3].Descriptor() - // machine.DefaultLastHeartbeat holds the default value on creation for the last_heartbeat field. - machine.DefaultLastHeartbeat = machineDescLastHeartbeat.Default.(func() time.Time) - // machine.UpdateDefaultLastHeartbeat holds the default value on update for the last_heartbeat field. - machine.UpdateDefaultLastHeartbeat = machineDescLastHeartbeat.UpdateDefault.(func() time.Time) // machineDescScenarios is the schema descriptor for scenarios field. machineDescScenarios := machineFields[7].Descriptor() // machine.ScenariosValidator is a validator for the "scenarios" field. It is called by the builders before save. @@ -172,7 +151,7 @@ func init() { // machine.DefaultIsValidated holds the default value on creation for the isValidated field. machine.DefaultIsValidated = machineDescIsValidated.Default.(bool) // machineDescAuthType is the schema descriptor for auth_type field. - machineDescAuthType := machineFields[11].Descriptor() + machineDescAuthType := machineFields[10].Descriptor() // machine.DefaultAuthType holds the default value on creation for the auth_type field. machine.DefaultAuthType = machineDescAuthType.Default.(string) metaFields := schema.Meta{}.Fields() @@ -181,8 +160,6 @@ func init() { metaDescCreatedAt := metaFields[0].Descriptor() // meta.DefaultCreatedAt holds the default value on creation for the created_at field. meta.DefaultCreatedAt = metaDescCreatedAt.Default.(func() time.Time) - // meta.UpdateDefaultCreatedAt holds the default value on update for the created_at field. - meta.UpdateDefaultCreatedAt = metaDescCreatedAt.UpdateDefault.(func() time.Time) // metaDescUpdatedAt is the schema descriptor for updated_at field. metaDescUpdatedAt := metaFields[1].Descriptor() // meta.DefaultUpdatedAt holds the default value on creation for the updated_at field. diff --git a/pkg/database/ent/runtime/runtime.go b/pkg/database/ent/runtime/runtime.go index e64f7bd7554..9cb9d96258a 100644 --- a/pkg/database/ent/runtime/runtime.go +++ b/pkg/database/ent/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/crowdsecurity/crowdsec/pkg/database/ent/runtime.go const ( - Version = "v0.11.3" // Version of ent codegen. - Sum = "h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc=" // Sum of ent codegen. + Version = "v0.13.1" // Version of ent codegen. + Sum = "h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE=" // Sum of ent codegen. ) diff --git a/pkg/database/ent/schema/alert.go b/pkg/database/ent/schema/alert.go index f2df9d7f09c..87ace24aa84 100644 --- a/pkg/database/ent/schema/alert.go +++ b/pkg/database/ent/schema/alert.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,38 +20,39 @@ func (Alert) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("scenario"), - field.String("bucketId").Default("").Optional(), - field.String("message").Default("").Optional(), - field.Int32("eventsCount").Default(0).Optional(), - field.Time("startedAt").Default(types.UtcNow).Optional(), - field.Time("stoppedAt").Default(types.UtcNow).Optional(), + UpdateDefault(types.UtcNow), + field.String("scenario").Immutable(), + field.String("bucketId").Default("").Optional().Immutable(), + field.String("message").Default("").Optional().Immutable(), + field.Int32("eventsCount").Default(0).Optional().Immutable(), + field.Time("startedAt").Default(types.UtcNow).Optional().Immutable(), + field.Time("stoppedAt").Default(types.UtcNow).Optional().Immutable(), field.String("sourceIp"). - Optional(), + Optional().Immutable(), field.String("sourceRange"). - Optional(), + Optional().Immutable(), field.String("sourceAsNumber"). - Optional(), + Optional().Immutable(), field.String("sourceAsName"). - Optional(), + Optional().Immutable(), field.String("sourceCountry"). - Optional(), + Optional().Immutable(), field.Float32("sourceLatitude"). - Optional(), + Optional().Immutable(), field.Float32("sourceLongitude"). - Optional(), - field.String("sourceScope").Optional(), - field.String("sourceValue").Optional(), - field.Int32("capacity").Optional(), - field.String("leakSpeed").Optional(), - field.String("scenarioVersion").Optional(), - field.String("scenarioHash").Optional(), - field.Bool("simulated").Default(false), - field.String("uuid").Optional(), //this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each alert + Optional().Immutable(), + field.String("sourceScope").Optional().Immutable(), + field.String("sourceValue").Optional().Immutable(), + field.Int32("capacity").Optional().Immutable(), + field.String("leakSpeed").Optional().Immutable(), + field.String("scenarioVersion").Optional().Immutable(), + field.String("scenarioHash").Optional().Immutable(), + field.Bool("simulated").Default(false).Immutable(), + field.String("uuid").Optional().Immutable(), // this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each alert + field.Bool("remediation").Optional().Immutable(), } } diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index c3081291254..599c4c404fc 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -3,6 +3,7 @@ package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -16,20 +17,22 @@ func (Bouncer) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"created_at"`), + StructTag(`json:"created_at"`). + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"updated_at"`), - field.String("name").Unique().StructTag(`json:"name"`), - field.String("api_key").StructTag(`json:"api_key"`), // hash of api_key + UpdateDefault(types.UtcNow).StructTag(`json:"updated_at"`), + field.String("name").Unique().StructTag(`json:"name"`).Immutable(), + field.String("api_key").Sensitive(), // hash of api_key field.Bool("revoked").StructTag(`json:"revoked"`), field.String("ip_address").Default("").Optional().StructTag(`json:"ip_address"`), field.String("type").Optional().StructTag(`json:"type"`), field.String("version").Optional().StructTag(`json:"version"`), - field.Time("until").Default(types.UtcNow).Optional().StructTag(`json:"until"`), - field.Time("last_pull"). - Default(types.UtcNow).StructTag(`json:"last_pull"`), + field.Time("last_pull").Nillable().Optional().StructTag(`json:"last_pull"`), field.String("auth_type").StructTag(`json:"auth_type"`).Default(types.ApiKeyAuthType), + field.String("osname").Optional(), + field.String("osversion").Optional(), + field.String("featureflags").Optional(), } } diff --git a/pkg/database/ent/schema/config.go b/pkg/database/ent/schema/config.go index f3320a9cce6..d526db25a8d 100644 --- a/pkg/database/ent/schema/config.go +++ b/pkg/database/ent/schema/config.go @@ -3,6 +3,7 @@ package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -11,21 +12,20 @@ type ConfigItem struct { ent.Schema } -// Fields of the Bouncer. func (ConfigItem) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"created_at"`), + Immutable(). + StructTag(`json:"created_at"`), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional().StructTag(`json:"updated_at"`), - field.String("name").Unique().StructTag(`json:"name"`), + UpdateDefault(types.UtcNow).StructTag(`json:"updated_at"`), + field.String("name").Unique().StructTag(`json:"name"`).Immutable(), field.String("value").StructTag(`json:"value"`), // a json object } } -// Edges of the Bouncer. func (ConfigItem) Edges() []ent.Edge { return nil } diff --git a/pkg/database/ent/schema/decision.go b/pkg/database/ent/schema/decision.go index b7a99fb7a70..4089be38096 100644 --- a/pkg/database/ent/schema/decision.go +++ b/pkg/database/ent/schema/decision.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,25 +20,25 @@ func (Decision) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + UpdateDefault(types.UtcNow), field.Time("until").Nillable().Optional().SchemaType(map[string]string{ dialect.MySQL: "datetime", }), - field.String("scenario"), - field.String("type"), - field.Int64("start_ip").Optional(), - field.Int64("end_ip").Optional(), - field.Int64("start_suffix").Optional(), - field.Int64("end_suffix").Optional(), - field.Int64("ip_size").Optional(), - field.String("scope"), - field.String("value"), - field.String("origin"), - field.Bool("simulated").Default(false), - field.String("uuid").Optional(), //this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each decision + field.String("scenario").Immutable(), + field.String("type").Immutable(), + field.Int64("start_ip").Optional().Immutable(), + field.Int64("end_ip").Optional().Immutable(), + field.Int64("start_suffix").Optional().Immutable(), + field.Int64("end_suffix").Optional().Immutable(), + field.Int64("ip_size").Optional().Immutable(), + field.String("scope").Immutable(), + field.String("value").Immutable(), + field.String("origin").Immutable(), + field.Bool("simulated").Default(false).Immutable(), + field.String("uuid").Optional().Immutable(), // this uuid is mostly here to ensure that CAPI/PAPI has a unique id for each decision field.Int("alert_decisions").Optional(), } } diff --git a/pkg/database/ent/schema/event.go b/pkg/database/ent/schema/event.go index 6b6d2733ff7..107f68e5274 100644 --- a/pkg/database/ent/schema/event.go +++ b/pkg/database/ent/schema/event.go @@ -5,6 +5,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -18,12 +19,12 @@ func (Event) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.Time("time"), - field.String("serialized").MaxLen(8191), + UpdateDefault(types.UtcNow), + field.Time("time").Immutable(), + field.String("serialized").MaxLen(8191).Immutable(), field.Int("alert_events").Optional(), } } diff --git a/pkg/database/ent/schema/lock.go b/pkg/database/ent/schema/lock.go new file mode 100644 index 00000000000..a287e2b59ad --- /dev/null +++ b/pkg/database/ent/schema/lock.go @@ -0,0 +1,23 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +type Lock struct { + ent.Schema +} + +func (Lock) Fields() []ent.Field { + return []ent.Field{ + field.String("name").Unique().Immutable().StructTag(`json:"name"`), + field.Time("created_at").Default(types.UtcNow).StructTag(`json:"created_at"`).Immutable(), + } +} + +func (Lock) Edges() []ent.Edge { + return nil +} diff --git a/pkg/database/ent/schema/machine.go b/pkg/database/ent/schema/machine.go index e155c936071..5b68f61b1a0 100644 --- a/pkg/database/ent/schema/machine.go +++ b/pkg/database/ent/schema/machine.go @@ -4,9 +4,17 @@ import ( "entgo.io/ent" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/types" ) +// ItemState is defined here instead of using pkg/models/HubItem to avoid introducing a dependency +type ItemState struct { + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + Version string `json:"version,omitempty"` +} + // Machine holds the schema definition for the Machine entity. type Machine struct { ent.Schema @@ -17,25 +25,30 @@ func (Machine) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + UpdateDefault(types.UtcNow), field.Time("last_push"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Nillable().Optional(), field.Time("last_heartbeat"). - Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("machineId").Unique(), + Nillable().Optional(), + field.String("machineId"). + Unique(). + Immutable(), field.String("password").Sensitive(), field.String("ipAddress"), field.String("scenarios").MaxLen(100000).Optional(), field.String("version").Optional(), field.Bool("isValidated"). Default(false), - field.String("status").Optional(), field.String("auth_type").Default(types.PasswordAuthType).StructTag(`json:"auth_type"`), + field.String("osname").Optional(), + field.String("osversion").Optional(), + field.String("featureflags").Optional(), + field.JSON("hubstate", map[string][]ItemState{}).Optional(), + field.JSON("datasources", map[string]int64{}).Optional(), } } diff --git a/pkg/database/ent/schema/meta.go b/pkg/database/ent/schema/meta.go index 1a84bb1b667..a87010cd8a3 100644 --- a/pkg/database/ent/schema/meta.go +++ b/pkg/database/ent/schema/meta.go @@ -5,6 +5,7 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" + "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -17,13 +18,12 @@ type Meta struct { func (Meta) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). - Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), + Default(types.UtcNow).Immutable(), field.Time("updated_at"). Default(types.UtcNow). - UpdateDefault(types.UtcNow).Nillable().Optional(), - field.String("key"), - field.String("value").MaxLen(4095), + UpdateDefault(types.UtcNow), + field.String("key").Immutable(), + field.String("value").MaxLen(4095).Immutable(), field.Int("alert_metas").Optional(), } } diff --git a/pkg/database/ent/schema/metric.go b/pkg/database/ent/schema/metric.go new file mode 100644 index 00000000000..319c67b7aa7 --- /dev/null +++ b/pkg/database/ent/schema/metric.go @@ -0,0 +1,34 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" +) + +// Metric is actually a set of metrics collected by a device +// (logprocessor, bouncer, etc) at a given time. +type Metric struct { + ent.Schema +} + +func (Metric) Fields() []ent.Field { + return []ent.Field{ + field.Enum("generated_type"). + Values("LP", "RC"). + Immutable(). + Comment("Type of the metrics source: LP=logprocessor, RC=remediation"), + field.String("generated_by"). + Immutable(). + Comment("Source of the metrics: machine id, bouncer name...\nIt must come from the auth middleware."), + field.Time("received_at"). + Immutable(). + Comment("When the metrics are received by LAPI"), + field.Time("pushed_at"). + Nillable(). + Optional(). + Comment("When the metrics are sent to the console"), + field.Text("payload"). + Immutable(). + Comment("The actual metrics (item0)"), + } +} diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go index 2a1efd152a0..bf8221ce4a5 100644 --- a/pkg/database/ent/tx.go +++ b/pkg/database/ent/tx.go @@ -22,20 +22,18 @@ type Tx struct { Decision *DecisionClient // Event is the client for interacting with the Event builders. Event *EventClient + // Lock is the client for interacting with the Lock builders. + Lock *LockClient // Machine is the client for interacting with the Machine builders. Machine *MachineClient // Meta is the client for interacting with the Meta builders. Meta *MetaClient + // Metric is the client for interacting with the Metric builders. + Metric *MetricClient // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. - mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -80,9 +78,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -91,9 +89,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -135,9 +134,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -146,9 +145,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. func (tx *Tx) OnRollback(f RollbackHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onRollback = append(tx.onRollback, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() } // Client returns a Client that binds to current transaction. @@ -166,8 +166,10 @@ func (tx *Tx) init() { tx.ConfigItem = NewConfigItemClient(tx.config) tx.Decision = NewDecisionClient(tx.config) tx.Event = NewEventClient(tx.config) + tx.Lock = NewLockClient(tx.config) tx.Machine = NewMachineClient(tx.config) tx.Meta = NewMetaClient(tx.config) + tx.Metric = NewMetricClient(tx.config) } // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. @@ -186,6 +188,10 @@ type txDriver struct { drv dialect.Driver // tx is the underlying transaction. tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook } // newTx creates a new transactional driver. diff --git a/pkg/database/errors.go b/pkg/database/errors.go index 8e96f52d7ce..77f92707e51 100644 --- a/pkg/database/errors.go +++ b/pkg/database/errors.go @@ -13,8 +13,8 @@ var ( ItemNotFound = errors.New("object not found") ParseTimeFail = errors.New("unable to parse time") ParseDurationFail = errors.New("unable to parse duration") - MarshalFail = errors.New("unable to marshal") - UnmarshalFail = errors.New("unable to unmarshal") + MarshalFail = errors.New("unable to serialize") + UnmarshalFail = errors.New("unable to parse") BulkError = errors.New("unable to insert bulk") ParseType = errors.New("unable to parse type") InvalidIPOrRange = errors.New("invalid ip address / range") diff --git a/pkg/database/flush.go b/pkg/database/flush.go index a7b364fa970..8f646ddc961 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -1,38 +1,52 @@ package database import ( + "context" + "errors" "fmt" "time" "github.com/go-co-op/gocron" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" "github.com/crowdsecurity/crowdsec/pkg/types" ) +const ( + // how long to keep metrics in the local database + defaultMetricsMaxAge = 7 * 24 * time.Hour + flushInterval = 1 * time.Minute +) -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { +func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { maxItems := 0 maxAge := "" + if config.MaxItems != nil && *config.MaxItems <= 0 { - return nil, fmt.Errorf("max_items can't be zero or negative number") + return nil, errors.New("max_items can't be zero or negative") } + if config.MaxItems != nil { maxItems = *config.MaxItems } + if config.MaxAge != nil && *config.MaxAge != "" { maxAge = *config.MaxAge } // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems) if err != nil { return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) } @@ -45,166 +59,176 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched if err != nil { return nil, fmt.Errorf("while parsing agents cert auto-delete duration: %w", err) } + config.AgentsGC.CertDuration = &duration } + if config.AgentsGC.LoginPassword != nil { duration, err := ParseDuration(*config.AgentsGC.LoginPassword) if err != nil { return nil, fmt.Errorf("while parsing agents login/password auto-delete duration: %w", err) } + config.AgentsGC.LoginPasswordDuration = &duration } + if config.AgentsGC.Api != nil { log.Warning("agents auto-delete for API auth is not supported (use cert or login_password)") } } + if config.BouncersGC != nil { if config.BouncersGC.Cert != nil { duration, err := ParseDuration(*config.BouncersGC.Cert) if err != nil { return nil, fmt.Errorf("while parsing bouncers cert auto-delete duration: %w", err) } + config.BouncersGC.CertDuration = &duration } + if config.BouncersGC.Api != nil { duration, err := ParseDuration(*config.BouncersGC.Api) if err != nil { return nil, fmt.Errorf("while parsing bouncers api auto-delete duration: %w", err) } + config.BouncersGC.ApiDuration = &duration } + if config.BouncersGC.LoginPassword != nil { log.Warning("bouncers auto-delete for login/password auth is not supported (use cert or api)") } } - baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + + baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) if err != nil { return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) } baJob.SingletonMode() + + metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, ctx, config.MetricsMaxAge) + if err != nil { + return nil, fmt.Errorf("while starting flushMetrics scheduler: %w", err) + } + + metricsJob.SingletonMode() + scheduler.StartAsync() return scheduler, nil } +// flushMetrics deletes metrics older than maxAge, regardless if they have been pushed to CAPI or not +func (c *Client) flushMetrics(ctx context.Context, maxAge *time.Duration) { + if maxAge == nil { + maxAge = ptr.Of(defaultMetricsMaxAge) + } + + c.Log.Debugf("flushing metrics older than %s", maxAge) + + deleted, err := c.Ent.Metric.Delete().Where( + metric.ReceivedAtLTE(time.Now().UTC().Add(-*maxAge)), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while flushing metrics: %s", err) + return + } + + if deleted > 0 { + c.Log.Debugf("flushed %d metrics snapshots", deleted) + } +} -func (c *Client) FlushOrphans() { +func (c *Client) FlushOrphans(ctx context.Context) { /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan events: %s", err) return } + if eventsCount > 0 { c.Log.Infof("%d deleted orphan events", eventsCount) } eventsCount, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) - + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan decisions: %s", err) return } + if eventsCount > 0 { c.Log.Infof("%d deleted orphan decisions", eventsCount) } } -func (c *Client) flushBouncers(bouncersCfg *csconfig.AuthGCCfg) { - if bouncersCfg == nil { +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { + if duration == nil { return } - if bouncersCfg.ApiDuration != nil { - log.Debug("trying to delete old bouncers from api") - - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.ApiDuration)), - ).Where( - bouncer.AuthTypeEQ(types.ApiKeyAuthType), - ).Exec(c.CTX) - if err != nil { - c.Log.Errorf("while auto-deleting expired bouncers (api key): %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } + count, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), + ).Where( + bouncer.AuthTypeEQ(authType), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) + return } - if bouncersCfg.CertDuration != nil { - log.Debug("trying to delete old bouncers from cert") - - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.CertDuration)), - ).Where( - bouncer.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - if err != nil { - c.Log.Errorf("while auto-deleting expired bouncers (api key): %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } + if count > 0 { + c.Log.Infof("deleted %d expired bouncers (%s)", count, authType) } } -func (c *Client) flushAgents(agentsCfg *csconfig.AuthGCCfg) { - if agentsCfg == nil { +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { + if duration == nil { return } - if agentsCfg.CertDuration != nil { - log.Debug("trying to delete old agents from cert") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.CertDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (cert): %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount) - } + count, err := c.Ent.Machine.Delete().Where( + machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), + machine.Not(machine.HasAlerts()), + machine.AuthTypeEQ(authType), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) + return } - if agentsCfg.LoginPasswordDuration != nil { - log.Debug("trying to delete old agents from password") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.LoginPasswordDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.PasswordAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (password): %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (password auth)", deletionCount) - } + if count > 0 { + c.Log.Infof("deleted %d expired machines (%s auth)", count, authType) } } -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { log.Debug("starting FlushAgentsAndBouncers") - c.flushBouncers(bouncersCfg) - c.flushAgents(agentsCfg) + if agentsCfg != nil { + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + } + + if bouncersCfg != nil { + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) + } return nil } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { - var deletedByAge int - var deletedByNbItem int - var totalAlerts int - var err error +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { + var ( + deletedByAge int + deletedByNbItem int + totalAlerts int + err error + ) if !c.CanFlush { c.Log.Debug("a list is being imported, flushing later") @@ -212,20 +236,23 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() + c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() + + totalAlerts, err = c.TotalAlerts(ctx) if err != nil { c.Log.Warningf("FlushAlerts (max items count): %s", err) return fmt.Errorf("unable to get alerts count: %w", err) } c.Log.Debugf("FlushAlerts (Total alerts): %d", totalAlerts) + if MaxAge != "" { filter := map[string][]string{ "created_before": {MaxAge}, } - nbDeleted, err := c.DeleteAlertWithFilter(filter) + + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) if err != nil { c.Log.Warningf("FlushAlerts (max age): %s", err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) @@ -234,19 +261,21 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { c.Log.Debugf("FlushAlerts (deleted max age alerts): %d", nbDeleted) deletedByAge = nbDeleted } + if MaxItems > 0 { - //We get the highest id for the alerts - //We subtract MaxItems to avoid deleting alerts that are not old enough - //This gives us the oldest alert that we want to keep - //We then delete all the alerts with an id lower than this one - //We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + // We get the highest id for the alerts + // We subtract MaxItems to avoid deleting alerts that are not old enough + // This gives us the oldest alert that we want to keep + // We then delete all the alerts with an id lower than this one + // We can do this because the id is auto-increment, and the database won't reuse the same id twice + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ "sort": {"DESC"}, "limit": {"1"}, - //we do not care about fetching the edges, we just want the id + // we do not care about fetching the edges, we just want the id "with_decisions": {"false"}, }) c.Log.Debugf("FlushAlerts (last alert): %+v", lastAlert) + if err != nil { c.Log.Errorf("FlushAlerts: could not get last alert: %s", err) return fmt.Errorf("could not get last alert: %w", err) @@ -258,9 +287,8 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { c.Log.Debugf("FlushAlerts (max id): %d", maxid) if maxid > 0 { - //This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) - + // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) if err != nil { c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) return fmt.Errorf("could not delete alerts: %w", err) @@ -268,11 +296,16 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } } } + if deletedByNbItem > 0 { - c.Log.Infof("flushed %d/%d alerts because the max number of alerts has been reached (%d max)", deletedByNbItem, totalAlerts, MaxItems) + c.Log.Infof("flushed %d/%d alerts because the max number of alerts has been reached (%d max)", + deletedByNbItem, totalAlerts, MaxItems) } + if deletedByAge > 0 { - c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", deletedByAge, totalAlerts, MaxAge) + c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", + deletedByAge, totalAlerts, MaxAge) } + return nil } diff --git a/pkg/database/lock.go b/pkg/database/lock.go new file mode 100644 index 00000000000..474228a069c --- /dev/null +++ b/pkg/database/lock.go @@ -0,0 +1,87 @@ +package database + +import ( + "context" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/lock" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +const ( + CAPIPullLockTimeout = 10 + CapiPullLockName = "pullCAPI" +) + +func (c *Client) AcquireLock(ctx context.Context, name string) error { + log.Debugf("acquiring lock %s", name) + _, err := c.Ent.Lock.Create(). + SetName(name). + SetCreatedAt(types.UtcNow()). + Save(ctx) + + if ent.IsConstraintError(err) { + return err + } + + if err != nil { + return errors.Wrapf(InsertFail, "insert lock: %s", err) + } + + return nil +} + +func (c *Client) ReleaseLock(ctx context.Context, name string) error { + log.Debugf("releasing lock %s", name) + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} + +func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error { + log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout) + + _, err := c.Ent.Lock.Delete().Where( + lock.NameEQ(name), + lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)), + ).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} + +func (c *Client) IsLocked(err error) bool { + return ent.IsConstraintError(err) +} + +func (c *Client) AcquirePullCAPILock(ctx context.Context) error { + // delete orphan "old" lock if present + err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout) + if err != nil { + log.Errorf("unable to release pullCAPI lock: %s", err) + } + + return c.AcquireLock(ctx, CapiPullLockName) +} + +func (c *Client) ReleasePullCAPILock(ctx context.Context) error { + log.Debugf("deleting lock %s", CapiPullLockName) + + _, err := c.Ent.Lock.Delete().Where( + lock.NameEQ(CapiPullLockName), + ).Exec(ctx) + if err != nil { + return errors.Wrapf(DeleteFail, "delete lock: %s", err) + } + + return nil +} diff --git a/pkg/database/machines.go b/pkg/database/machines.go index b9834e57e09..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -1,7 +1,9 @@ package database import ( + "context" "fmt" + "strings" "time" "github.com/go-openapi/strfmt" @@ -10,13 +12,67 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/schema" + "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) -const CapiMachineID = types.CAPIOrigin -const CapiListsMachineID = types.ListOrigin +const ( + CapiMachineID = types.CAPIOrigin + CapiListsMachineID = types.ListOrigin +) + +type MachineNotFoundError struct { + MachineID string +} + +func (e *MachineNotFoundError) Error() string { + return fmt.Sprintf("'%s' does not exist", e.MachineID) +} + +func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { + os := baseMetrics.Os + features := strings.Join(baseMetrics.FeatureFlags, ",") + + var heartbeat time.Time + + if len(baseMetrics.Metrics) == 0 { + heartbeat = time.Now().UTC() + } else { + heartbeat = time.Unix(*baseMetrics.Metrics[0].Meta.UtcNowTimestamp, 0) + } + + hubState := map[string][]schema.ItemState{} + for itemType, items := range hubItems { + hubState[itemType] = []schema.ItemState{} + for _, item := range items { + hubState[itemType] = append(hubState[itemType], schema.ItemState{ + Name: item.Name, + Status: item.Status, + Version: item.Version, + }) + } + } + + _, err := c.Ent.Machine. + Update(). + Where(machine.MachineIdEQ(machineID)). + SetNillableVersion(baseMetrics.Version). + SetOsname(*os.Name). + SetOsversion(*os.Version). + SetFeatureflags(features). + SetLastHeartbeat(heartbeat). + SetHubstate(hubState). + SetDatasources(datasources). + Save(ctx) + if err != nil { + return fmt.Errorf("unable to update base machine metrics in database: %w", err) + } -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { + return nil +} + +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine: %s", err) @@ -26,23 +82,27 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } + return machine, nil } + return nil, errors.Wrapf(UserExists, "user '%s'", *machineID) } @@ -53,8 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) - + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) @@ -63,140 +122,146 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) } + return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } + return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } + if rets == 0 { - return fmt.Errorf("machine not found") + return errors.New("machine not found") } + return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - var machines []*ent.Machine - var err error - - machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) } + return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } if nbDeleted == 0 { - return fmt.Errorf("machine doesn't exist") + return &MachineNotFoundError{MachineID: name} } return nil } -func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { ids := make([]int, len(machines)) for i, b := range machines { ids[i] = b.ID } - nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) + + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, err } - return nbDeleted, nil -} -func (c *Client) UpdateMachineLastPush(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastPush(time.Now().UTC()).Save(c.CTX) - if err != nil { - return errors.Wrapf(UpdateFail, "updating machine last_push: %s", err) - } - return nil + return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } + return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine in database: %s", err) + return fmt.Errorf("unable to update machine in database: %w", err) } + return nil } -func (c *Client) UpdateMachineIP(ipAddr string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine IP in database: %s", err) + return fmt.Errorf("unable to update machine IP in database: %w", err) } + return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, ID int) error { - _, err := c.Ent.Machine.UpdateOneID(ID). +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { - return fmt.Errorf("unable to update machine version in database: %s", err) + return fmt.Errorf("unable to update machine version in database: %w", err) } + return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } + if len(exist) == 1 { return true, nil } + if len(exist) > 1 { - return false, fmt.Errorf("more than one item with the same machineID in database") + return false, errors.New("more than one item with the same machineID in database") } return false, nil - } -func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) { - return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX) +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { + return c.Ent.Machine.Query().Where( + machine.Or( + machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), + machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), + ), + ).All(ctx) } diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go new file mode 100644 index 00000000000..eb4c472821e --- /dev/null +++ b/pkg/database/metrics.go @@ -0,0 +1,71 @@ +package database + +import ( + "context" + "fmt" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" +) + +func (c *Client) CreateMetric(ctx context.Context, generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { + metric, err := c.Ent.Metric. + Create(). + SetGeneratedType(generatedType). + SetGeneratedBy(generatedBy). + SetReceivedAt(receivedAt). + SetPayload(payload). + Save(ctx) + if err != nil { + c.Log.Warningf("CreateMetric: %s", err) + return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail) + } + + return metric, nil +} + +func (c *Client) GetLPUsageMetricsByMachineID(ctx context.Context, machineId string) ([]*ent.Metric, error) { + metrics, err := c.Ent.Metric.Query(). + Where( + metric.GeneratedTypeEQ(metric.GeneratedTypeLP), + metric.GeneratedByEQ(machineId), + metric.PushedAtIsNil(), + ). + All(ctx) + if err != nil { + c.Log.Warningf("GetLPUsageMetricsByOrigin: %s", err) + return nil, fmt.Errorf("getting LP usage metrics by origin %s: %w", machineId, err) + } + + return metrics, nil +} + +func (c *Client) GetBouncerUsageMetricsByName(ctx context.Context, bouncerName string) ([]*ent.Metric, error) { + metrics, err := c.Ent.Metric.Query(). + Where( + metric.GeneratedTypeEQ(metric.GeneratedTypeRC), + metric.GeneratedByEQ(bouncerName), + metric.PushedAtIsNil(), + ). + All(ctx) + if err != nil { + c.Log.Warningf("GetBouncerUsageMetricsByName: %s", err) + return nil, fmt.Errorf("getting bouncer usage metrics by name %s: %w", bouncerName, err) + } + + return metrics, nil +} + +func (c *Client) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + _, err := c.Ent.Metric.Update(). + Where(metric.IDIn(ids...)). + SetPushedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + c.Log.Warningf("MarkUsageMetricsAsSent: %s", err) + return fmt.Errorf("marking usage metrics as sent: %w", err) + } + + return nil +} diff --git a/pkg/database/utils.go b/pkg/database/utils.go index 2414e702786..8148df56f24 100644 --- a/pkg/database/utils.go +++ b/pkg/database/utils.go @@ -13,12 +13,14 @@ func IP2Int(ip net.IP) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) } + return binary.BigEndian.Uint32(ip) } func Int2ip(nn uint32) net.IP { ip := make(net.IP, 4) binary.BigEndian.PutUint32(ip, nn) + return ip } @@ -26,20 +28,22 @@ func IsIpv4(host string) bool { return net.ParseIP(host) != nil } -//Stolen from : https://github.com/llimllib/ipaddress/ +// Stolen from : https://github.com/llimllib/ipaddress/ // Return the final address of a net range. Convert to IPv4 if possible, // otherwise return an ipv6 func LastAddress(n *net.IPNet) net.IP { ip := n.IP.To4() if ip == nil { ip = n.IP + return net.IP{ ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2], ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5], ip[6] | ^n.Mask[6], ip[7] | ^n.Mask[7], ip[8] | ^n.Mask[8], ip[9] | ^n.Mask[9], ip[10] | ^n.Mask[10], ip[11] | ^n.Mask[11], ip[12] | ^n.Mask[12], ip[13] | ^n.Mask[13], ip[14] | ^n.Mask[14], - ip[15] | ^n.Mask[15]} + ip[15] | ^n.Mask[15], + } } return net.IPv4( @@ -49,40 +53,44 @@ func LastAddress(n *net.IPNet) net.IP { ip[3]|^n.Mask[3]) } +// GetIpsFromIpRange takes a CIDR range and returns the start and end IP func GetIpsFromIpRange(host string) (int64, int64, error) { - var ipStart int64 - var ipEnd int64 - var err error - var parsedRange *net.IPNet - - if _, parsedRange, err = net.ParseCIDR(host); err != nil { - return ipStart, ipEnd, fmt.Errorf("'%s' is not a valid CIDR", host) + _, parsedRange, err := net.ParseCIDR(host) + if err != nil { + return 0, 0, fmt.Errorf("'%s' is not a valid CIDR", host) } + if parsedRange == nil { - return ipStart, ipEnd, fmt.Errorf("unable to parse network : %s", err) + return 0, 0, fmt.Errorf("unable to parse network: %w", err) } - ipStart = int64(IP2Int(parsedRange.IP)) - ipEnd = int64(IP2Int(LastAddress(parsedRange))) + + ipStart := int64(IP2Int(parsedRange.IP)) + ipEnd := int64(IP2Int(LastAddress(parsedRange))) return ipStart, ipEnd, nil } func ParseDuration(d string) (time.Duration, error) { durationStr := d + if strings.HasSuffix(d, "d") { days := strings.Split(d, "d")[0] - if len(days) == 0 { + if days == "" { return 0, fmt.Errorf("'%s' can't be parsed as duration", d) } + daysInt, err := strconv.Atoi(days) if err != nil { return 0, err } + durationStr = strconv.Itoa(daysInt*24) + "h" } + duration, err := time.ParseDuration(durationStr) if err != nil { return 0, err } + return duration, nil } diff --git a/pkg/dumps/bucket_dump.go b/pkg/dumps/bucket_dump.go index 5f5ce1c4028..328c581928b 100644 --- a/pkg/dumps/bucket_dump.go +++ b/pkg/dumps/bucket_dump.go @@ -4,8 +4,9 @@ import ( "io" "os" + "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/types" - "gopkg.in/yaml.v2" ) type BucketPourInfo map[string][]types.Event diff --git a/pkg/dumps/parser_dump.go b/pkg/dumps/parser_dump.go index 566b87a0803..bc8f78dc203 100644 --- a/pkg/dumps/parser_dump.go +++ b/pkg/dumps/parser_dump.go @@ -1,6 +1,7 @@ package dumps import ( + "errors" "fmt" "io" "os" @@ -8,13 +9,15 @@ import ( "strings" "time" - "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/crowdsecurity/go-cs-lib/maptools" - "github.com/enescakir/emoji" "github.com/fatih/color" diff "github.com/r3labs/diff/v2" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type ParserResult struct { @@ -56,7 +59,7 @@ func LoadParserDump(filepath string) (*ParserResults, error) { var lastStage string - //Loop over stages to find last successful one with at least one parser + // Loop over stages to find last successful one with at least one parser for i := len(stages) - 2; i >= 0; i-- { if len(pdump[stages[i]]) != 0 { lastStage = stages[i] @@ -73,7 +76,7 @@ func LoadParserDump(filepath string) (*ParserResults, error) { sort.Strings(parsers) if len(parsers) == 0 { - return nil, fmt.Errorf("no parser found. Please install the appropriate parser and retry") + return nil, errors.New("no parser found. Please install the appropriate parser and retry") } lastParser := parsers[len(parsers)-1] @@ -89,80 +92,103 @@ func LoadParserDump(filepath string) (*ParserResults, error) { return &pdump, nil } +type tree struct { + // note : we can use line -> time as the unique identifier (of acquisition) + state map[time.Time]map[string]map[string]ParserResult + assoc map[time.Time]string + parserOrder map[string][]string +} + +func newTree() *tree { + return &tree{ + state: make(map[time.Time]map[string]map[string]ParserResult), + assoc: make(map[time.Time]string), + parserOrder: make(map[string][]string), + } +} + func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) { - //note : we can use line -> time as the unique identifier (of acquisition) - state := make(map[time.Time]map[string]map[string]ParserResult) - assoc := make(map[time.Time]string, 0) - parser_order := make(map[string][]string) + t := newTree() + t.processEvents(parserResults) + t.processBuckets(bucketPour) + t.displayResults(opts) +} +func (t *tree) processEvents(parserResults ParserResults) { for stage, parsers := range parserResults { - //let's process parsers in the order according to idx - parser_order[stage] = make([]string, len(parsers)) + // let's process parsers in the order according to idx + t.parserOrder[stage] = make([]string, len(parsers)) + for pname, parser := range parsers { if len(parser) > 0 { - parser_order[stage][parser[0].Idx-1] = pname + t.parserOrder[stage][parser[0].Idx-1] = pname } } - for _, parser := range parser_order[stage] { + for _, parser := range t.parserOrder[stage] { results := parsers[parser] for _, parserRes := range results { evt := parserRes.Evt - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw + if _, ok := t.state[evt.Line.Time]; !ok { + t.state[evt.Line.Time] = make(map[string]map[string]ParserResult) + t.assoc[evt.Line.Time] = evt.Line.Raw } - if _, ok := state[evt.Line.Time][stage]; !ok { - state[evt.Line.Time][stage] = make(map[string]ParserResult) + if _, ok := t.state[evt.Line.Time][stage]; !ok { + t.state[evt.Line.Time][stage] = make(map[string]ParserResult) } - state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success} + t.state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success} } } } +} +func (t *tree) processBuckets(bucketPour BucketPourInfo) { for bname, evtlist := range bucketPour { for _, evt := range evtlist { if evt.Line.Raw == "" { continue } - //it might be bucket overflow being reprocessed, skip this - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw + // it might be bucket overflow being reprocessed, skip this + if _, ok := t.state[evt.Line.Time]; !ok { + t.state[evt.Line.Time] = make(map[string]map[string]ParserResult) + t.assoc[evt.Line.Time] = evt.Line.Raw } - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if _, ok := state[evt.Line.Time]["buckets"]; !ok { - state[evt.Line.Time]["buckets"] = make(map[string]ParserResult) + // there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + // we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + if _, ok := t.state[evt.Line.Time]["buckets"]; !ok { + t.state[evt.Line.Time]["buckets"] = make(map[string]ParserResult) } - state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true} + t.state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true} } } +} +func (t *tree) displayResults(opts DumpOpts) { yellow := color.New(color.FgYellow).SprintFunc() red := color.New(color.FgRed).SprintFunc() green := color.New(color.FgGreen).SprintFunc() whitelistReason := "" - //get each line - for tstamp, rawstr := range assoc { + + // get each line + for tstamp, rawstr := range t.assoc { if opts.SkipOk { - if _, ok := state[tstamp]["buckets"]["OK"]; ok { + if _, ok := t.state[tstamp]["buckets"]["OK"]; ok { continue } } fmt.Printf("line: %s\n", rawstr) - skeys := make([]string, 0, len(state[tstamp])) + skeys := make([]string, 0, len(t.state[tstamp])) - for k := range state[tstamp] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + for k := range t.state[tstamp] { + // there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + // we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered if k == "buckets" { continue } @@ -176,18 +202,18 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO var prevItem types.Event for _, stage := range skeys { - parsers := state[tstamp][stage] + parsers := t.state[tstamp][stage] sep := "├" presep := "|" fmt.Printf("\t%s %s\n", sep, stage) - for idx, parser := range parser_order[stage] { + for idx, parser := range t.parserOrder[stage] { res := parsers[parser].Success sep := "├" - if idx == len(parser_order[stage])-1 { + if idx == len(t.parserOrder[stage])-1 { sep = "└" } @@ -209,13 +235,14 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO case "update": detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To)) - if change.Path[0] == "Whitelisted" && change.To == true { + if change.Path[0] == "Whitelisted" && change.To == true { //nolint:revive whitelisted = true if whitelistReason == "" { whitelistReason = parsers[parser].Evt.WhitelistReason } } + updated++ case "delete": deleted++ @@ -232,7 +259,7 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO } if updated > 0 { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } @@ -240,7 +267,7 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO } if deleted > 0 { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } @@ -248,7 +275,7 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO } if whitelisted { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } @@ -273,12 +300,12 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO sep := "└" - if len(state[tstamp]["buckets"]) > 0 { + if len(t.state[tstamp]["buckets"]) > 0 { sep = "├" } - //did the event enter the bucket pour phase ? - if _, ok := state[tstamp]["buckets"]["OK"]; ok { + // did the event enter the bucket pour phase ? + if _, ok := t.state[tstamp]["buckets"]["OK"]; ok { fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle) } else if whitelistReason != "" { fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle) @@ -286,16 +313,16 @@ func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpO fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle) } - //now print bucket info - if len(state[tstamp]["buckets"]) > 0 { + // now print bucket info + if len(t.state[tstamp]["buckets"]) > 0 { fmt.Printf("\t├ Scenarios\n") } - bnames := make([]string, 0, len(state[tstamp]["buckets"])) + bnames := make([]string, 0, len(t.state[tstamp]["buckets"])) - for k := range state[tstamp]["buckets"] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + for k := range t.state[tstamp]["buckets"] { + // there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + // we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered if k == "OK" { continue } diff --git a/pkg/emoji/emoji.go b/pkg/emoji/emoji.go new file mode 100644 index 00000000000..51295a85411 --- /dev/null +++ b/pkg/emoji/emoji.go @@ -0,0 +1,14 @@ +package emoji + +const ( + CheckMarkButton = "\u2705" // ✅ + CheckMark = "\u2714\ufe0f" // ✔️ + CrossMark = "\u274c" // ❌ + GreenCircle = "\U0001f7e2" // 🟢 + House = "\U0001f3e0" // 🏠 + Package = "\U0001f4e6" // 📦 + Prohibited = "\U0001f6ab" // 🚫 + QuestionMark = "\u2753" // ❓ + RedCircle = "\U0001f534" // 🔴 + Warning = "\u26a0\ufe0f" // ⚠️ +) diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index 59a239722e3..ccd67b27a49 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -1,14 +1,15 @@ package exprhelpers import ( + "errors" "fmt" "time" "github.com/bluele/gcache" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/cticlient" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" ) var CTIUrl = "https://cti.api.crowdsec.net" @@ -20,7 +21,7 @@ var CTIApiEnabled = false // when hitting quotas or auth errors, we temporarily disable the API var CTIBackOffUntil time.Time -var CTIBackOffDuration time.Duration = 5 * time.Minute +var CTIBackOffDuration = 5 * time.Minute var ctiClient *cticlient.CrowdsecCTIClient @@ -40,15 +41,12 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L } clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - return errors.Wrap(err, "while configuring datasource logger") + return fmt.Errorf("while configuring datasource logger: %w", err) } if LogLevel != nil { clog.SetLevel(*LogLevel) } - customLog := log.Fields{ - "type": "crowdsec-cti", - } - subLogger := clog.WithFields(customLog) + subLogger := clog.WithField("type", "crowdsec-cti") CrowdsecCTIInitCache(*Size, *TTL) ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey(CTIApiKey), cticlient.WithLogger(subLogger)) CTIApiEnabled = true @@ -86,12 +84,11 @@ func CrowdsecCTI(params ...any) (any, error) { if val, err := CTICache.Get(ip); err == nil && val != nil { ctiClient.Logger.Debugf("cti cache fetch for %s", ip) ret, ok := val.(*cticlient.SmokeItem) - if !ok { - ctiClient.Logger.Warningf("CrowdsecCTI: invalid type in cache, removing") - CTICache.Remove(ip) - } else { + if ok { return ret, nil } + ctiClient.Logger.Warningf("CrowdsecCTI: invalid type in cache, removing") + CTICache.Remove(ip) } if !CTIBackOffUntil.IsZero() && time.Now().Before(CTIBackOffUntil) { @@ -115,7 +112,7 @@ func CrowdsecCTI(params ...any) (any, error) { return &cticlient.SmokeItem{}, cticlient.ErrLimit default: ctiClient.Logger.Warnf("CTI API error : %s", err) - return &cticlient.SmokeItem{}, fmt.Errorf("unexpected error : %v", err) + return &cticlient.SmokeItem{}, fmt.Errorf("unexpected error: %w", err) } } diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index fc3a236c561..9f78b932d6d 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -69,7 +69,7 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { } func smokeHandler(req *http.Request) *http.Response { - apiKey := req.Header.Get("x-api-key") + apiKey := req.Header.Get("X-Api-Key") if apiKey != validApiKey { return &http.Response{ StatusCode: http.StatusForbidden, @@ -109,7 +109,7 @@ func smokeHandler(req *http.Request) *http.Response { } } -func TestNillClient(t *testing.T) { +func TestNilClient(t *testing.T) { defer ShutdownCrowdsecCTI() if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { @@ -118,7 +118,7 @@ func TestNillClient(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, err, cticlient.ErrDisabled) - assert.Equal(t, item, &cticlient.SmokeItem{}) + assert.Equal(t, &cticlient.SmokeItem{}, item) } func TestInvalidAuth(t *testing.T) { @@ -133,7 +133,7 @@ func TestInvalidAuth(t *testing.T) { })) item, err := CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) + assert.Equal(t, &cticlient.SmokeItem{}, item) assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrUnauthorized) @@ -143,7 +143,7 @@ func TestInvalidAuth(t *testing.T) { })) item, err = CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) + assert.Equal(t, &cticlient.SmokeItem{}, item) assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } @@ -159,7 +159,7 @@ func TestNoKey(t *testing.T) { })) item, err := CrowdsecCTI("1.2.3.4") - assert.Equal(t, item, &cticlient.SmokeItem{}) + assert.Equal(t, &cticlient.SmokeItem{}, item) assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } diff --git a/pkg/exprhelpers/debugger.go b/pkg/exprhelpers/debugger.go index 432bb737eae..2e47af6d1de 100644 --- a/pkg/exprhelpers/debugger.go +++ b/pkg/exprhelpers/debugger.go @@ -5,8 +5,9 @@ import ( "strconv" "strings" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" ) @@ -52,9 +53,8 @@ type OpOutput struct { } func (o *OpOutput) String() string { - ret := fmt.Sprintf("%*c", o.CodeDepth, ' ') - if len(o.Code) != 0 { + if o.Code != "" { ret += fmt.Sprintf("[%s]", o.Code) } ret += " " @@ -69,7 +69,7 @@ func (o *OpOutput) String() string { indent = 0 } ret = fmt.Sprintf("%*cBLOCK_END [%s]", indent, ' ', o.Code) - if len(o.StrConditionResult) > 0 { + if o.StrConditionResult != "" { ret += fmt.Sprintf(" -> %s", o.StrConditionResult) } return ret @@ -106,62 +106,30 @@ func (o *OpOutput) String() string { return ret + "" } -func (erp ExprRuntimeDebug) extractCode(ip int, program *vm.Program, parts []string) string { +func (erp ExprRuntimeDebug) extractCode(ip int, program *vm.Program) string { + locations := program.Locations() + src := string(program.Source()) - //log.Tracef("# extracting code for ip %d [%s]", ip, parts[1]) - if program.Locations[ip].Line == 0 { //it seems line is zero when it's not actual code (ie. op push at the beginning) - log.Tracef("zero location ?") - return "" - } - startLine := program.Locations[ip].Line - startColumn := program.Locations[ip].Column - lines := strings.Split(program.Source.Content(), "\n") + currentInstruction := locations[ip] - endCol := 0 - endLine := 0 + var closest *file.Location - for i := ip + 1; i < len(program.Locations); i++ { - if program.Locations[i].Line > startLine || (program.Locations[i].Line == startLine && program.Locations[i].Column > startColumn) { - //we didn't had values yet and it's superior to current one, take it - if endLine == 0 && endCol == 0 { - endLine = program.Locations[i].Line - endCol = program.Locations[i].Column + for i := ip + 1; i < len(locations); i++ { + if locations[i].From > currentInstruction.From { + if closest == nil || locations[i].From < closest.From { + closest = &locations[i] } - //however, we are looking for the closest upper one - if program.Locations[i].Line < endLine || (program.Locations[i].Line == endLine && program.Locations[i].Column < endCol) { - endLine = program.Locations[i].Line - endCol = program.Locations[i].Column - } - } } - //maybe it was the last instruction ? - if endCol == 0 && endLine == 0 { - endLine = len(lines) - endCol = len(lines[endLine-1]) - } - code_snippet := "" - startLine -= 1 //line count starts at 1 - endLine -= 1 - for i := startLine; i <= endLine; i++ { - if i == startLine { - if startLine != endLine { - code_snippet += lines[i][startColumn:] - continue - } - code_snippet += lines[i][startColumn:endCol] - break - } - if i == endLine { - code_snippet += lines[i][:endCol] - break - } - code_snippet += lines[i] + var end int + if closest == nil { + end = len(src) + } else { + end = closest.From } - log.Tracef("#code extract for ip %d [%s] -> '%s'", ip, parts[1], code_snippet) - return cleanTextForDebug(code_snippet) + return cleanTextForDebug(src[locations[ip].From:end]) } func autoQuote(v any) string { @@ -189,7 +157,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part prevIdxOut = IdxOut - 1 currentDepth = outputs[prevIdxOut].CodeDepth if outputs[prevIdxOut].Func && !outputs[prevIdxOut].Finalized { - stack := vm.Stack() + stack := vm.Stack num_items := 1 for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { outputs[prevIdxOut].FuncResults = append(outputs[prevIdxOut].FuncResults, autoQuote(stack[i])) @@ -197,7 +165,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part } outputs[prevIdxOut].Finalized = true } else if (outputs[prevIdxOut].Comparison || outputs[prevIdxOut].Condition) && !outputs[prevIdxOut].Finalized { - stack := vm.Stack() + stack := vm.Stack outputs[prevIdxOut].StrConditionResult = fmt.Sprintf("%+v", stack) if val, ok := stack[0].(bool); ok { outputs[prevIdxOut].ConditionResult = new(bool) @@ -207,10 +175,10 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part } } - erp.Logger.Tracef("[STEP %d:%s] (stack:%+v) (parts:%+v) {depth:%d}", ip, parts[1], vm.Stack(), parts, currentDepth) + erp.Logger.Tracef("[STEP %d:%s] (stack:%+v) (parts:%+v) {depth:%d}", ip, parts[1], vm.Stack, parts, currentDepth) out := OpOutput{} out.CodeDepth = currentDepth - out.Code = erp.extractCode(ip, program, parts) + out.Code = erp.extractCode(ip, program) switch parts[1] { case "OpBegin": @@ -221,8 +189,8 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part out.CodeDepth -= IndentStep out.BlockEnd = true //OpEnd can carry value, if it's any/all/count etc. - if len(vm.Stack()) > 0 { - out.StrConditionResult = fmt.Sprintf("%v", vm.Stack()) + if len(vm.Stack) > 0 { + out.StrConditionResult = fmt.Sprintf("%v", vm.Stack) } outputs = append(outputs, out) case "OpNot": @@ -241,7 +209,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part out.StrConditionResult = "false" outputs = append(outputs, out) case "OpJumpIfTrue": //OR - stack := vm.Stack() + stack := vm.Stack out.JumpIf = true out.IfTrue = true out.StrConditionResult = fmt.Sprintf("%v", stack[0]) @@ -252,7 +220,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part } outputs = append(outputs, out) case "OpJumpIfFalse": //AND - stack := vm.Stack() + stack := vm.Stack out.JumpIf = true out.IfFalse = true out.StrConditionResult = fmt.Sprintf("%v", stack[0]) @@ -264,7 +232,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part case "OpCall1": //Op for function calls out.Func = true out.FuncName = parts[3] - stack := vm.Stack() + stack := vm.Stack num_items := 1 for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { out.Args = append(out.Args, autoQuote(stack[i])) @@ -274,7 +242,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part case "OpCall2": //Op for function calls out.Func = true out.FuncName = parts[3] - stack := vm.Stack() + stack := vm.Stack num_items := 2 for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { out.Args = append(out.Args, autoQuote(stack[i])) @@ -284,7 +252,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part case "OpCall3": //Op for function calls out.Func = true out.FuncName = parts[3] - stack := vm.Stack() + stack := vm.Stack num_items := 3 for i := len(stack) - 1; i >= 0 && num_items > 0; i-- { out.Args = append(out.Args, autoQuote(stack[i])) @@ -297,7 +265,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part case "OpCallN": //Op for function calls with more than 3 args out.Func = true out.FuncName = parts[1] - stack := vm.Stack() + stack := vm.Stack //for OpCallN, we get the number of args if len(program.Arguments) >= ip { @@ -310,19 +278,19 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part } } } else { //let's blindly take the items on stack - for _, val := range vm.Stack() { + for _, val := range vm.Stack { out.Args = append(out.Args, autoQuote(val)) } } outputs = append(outputs, out) case "OpEqualString", "OpEqual", "OpEqualInt": //comparisons - stack := vm.Stack() + stack := vm.Stack out.Comparison = true out.Left = autoQuote(stack[0]) out.Right = autoQuote(stack[1]) outputs = append(outputs, out) case "OpIn": //in operator - stack := vm.Stack() + stack := vm.Stack out.Condition = true out.ConditionIn = true //seems that we tend to receive stack[1] as a map. @@ -332,7 +300,7 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part out.Args = append(out.Args, autoQuote(stack[1])) outputs = append(outputs, out) case "OpContains": //kind OpIn , but reverse - stack := vm.Stack() + stack := vm.Stack out.Condition = true out.ConditionContains = true //seems that we tend to receive stack[1] as a map. @@ -346,8 +314,11 @@ func (erp ExprRuntimeDebug) ipDebug(ip int, vm *vm.VM, program *vm.Program, part } func (erp ExprRuntimeDebug) ipSeek(ip int) []string { - for i := 0; i < len(erp.Lines); i++ { - parts := strings.Split(erp.Lines[i], "\t") + for i := range len(erp.Lines) { + parts := strings.Fields(erp.Lines[i]) + if len(parts) == 0 { + continue + } if parts[0] == strconv.Itoa(ip) { return parts } @@ -371,7 +342,7 @@ func cleanTextForDebug(text string) string { } func DisplayExprDebug(program *vm.Program, outputs []OpOutput, logger *log.Entry, ret any) { - logger.Debugf("dbg(result=%v): %s", ret, cleanTextForDebug(program.Source.Content())) + logger.Debugf("dbg(result=%v): %s", ret, cleanTextForDebug(string(program.Source()))) for _, output := range outputs { logger.Debugf("%s", output.String()) } @@ -379,62 +350,45 @@ func DisplayExprDebug(program *vm.Program, outputs []OpOutput, logger *log.Entry // TBD: Based on the level of the logger (ie. trace vs debug) we could decide to add more low level instructions (pop, push, etc.) func RunWithDebug(program *vm.Program, env interface{}, logger *log.Entry) ([]OpOutput, any, error) { - - var outputs []OpOutput = []OpOutput{} - var buf strings.Builder - var erp ExprRuntimeDebug = ExprRuntimeDebug{ + outputs := []OpOutput{} + erp := ExprRuntimeDebug{ Logger: logger, } - var debugErr chan error = make(chan error) vm := vm.Debug() - done := false - program.Opcodes(&buf) - lines := strings.Split(buf.String(), "\n") + opcodes := program.Disassemble() + lines := strings.Split(opcodes, "\n") erp.Lines = lines go func() { + //We must never return until the execution of the program is done var err error erp.Logger.Tracef("[START] ip 0") ops := erp.ipSeek(0) if ops == nil { - debugErr <- fmt.Errorf("failed getting ops for ip 0") - return + log.Warningf("error while debugging expr: failed getting ops for ip 0") } if outputs, err = erp.ipDebug(0, vm, program, ops, outputs); err != nil { - debugErr <- fmt.Errorf("error while debugging at ip 0") + log.Warningf("error while debugging expr: error while debugging at ip 0") } vm.Step() for ip := range vm.Position() { ops := erp.ipSeek(ip) - if ops == nil { //we reached the end of the program, we shouldn't throw an error + if ops == nil { erp.Logger.Tracef("[DONE] ip %d", ip) - debugErr <- nil - return + break } if outputs, err = erp.ipDebug(ip, vm, program, ops, outputs); err != nil { - debugErr <- fmt.Errorf("error while debugging at ip %d", ip) - return - } - if done { - debugErr <- nil - return + log.Warningf("error while debugging expr: error while debugging at ip %d", ip) } vm.Step() } - debugErr <- nil }() var return_error error ret, err := vm.Run(program, env) - done = true //if the expr runtime failed, we don't need to wait for the debug to finish if err != nil { return_error = err - } else { - err = <-debugErr - if err != nil { - log.Warningf("error while debugging expr: %s", err) - } } //the overall result of expression is the result of last op ? if len(outputs) > 0 { diff --git a/pkg/exprhelpers/debugger_test.go b/pkg/exprhelpers/debugger_test.go index 9c713a8d4f5..32144454084 100644 --- a/pkg/exprhelpers/debugger_test.go +++ b/pkg/exprhelpers/debugger_test.go @@ -5,10 +5,11 @@ import ( "strings" "testing" - "github.com/antonmedv/expr" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) type ExprDbgTest struct { @@ -25,6 +26,7 @@ type ExprDbgTest struct { func UpperTwo(params ...any) (any, error) { s := params[0].(string) v := params[1].(string) + return strings.ToUpper(s) + strings.ToUpper(v), nil } @@ -32,6 +34,7 @@ func UpperThree(params ...any) (any, error) { s := params[0].(string) v := params[1].(string) x := params[2].(string) + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x), nil } @@ -40,6 +43,7 @@ func UpperN(params ...any) (any, error) { v := params[1].(string) x := params[2].(string) y := params[3].(string) + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x) + strings.ToUpper(y), nil } @@ -51,6 +55,7 @@ type teststruct struct { Foo string } +// You need to add the tag expr_debug when running the tests func TestBaseDbg(t *testing.T) { defaultEnv := map[string]interface{}{ "queue": &types.Queue{}, @@ -59,7 +64,7 @@ func TestBaseDbg(t *testing.T) { "base_string": "hello world", "base_int": 42, "base_float": 42.42, - "nillvar": &teststruct{}, + "nilvar": &teststruct{}, "base_struct": struct { Foo string Bar int @@ -74,13 +79,13 @@ func TestBaseDbg(t *testing.T) { // use '%#v' to dump in golang syntax // use regexp to clear empty/default fields: // [a-z]+: (false|\[\]string\(nil\)|""), - //ConditionResult:(*bool) + // ConditionResult:(*bool) - //Missing multi parametes function + // Missing multi parametes function tests := []ExprDbgTest{ { - Name: "nill deref", - Expr: "Upper('1') == '1' && nillvar.Foo == '42'", + Name: "nil deref", + Expr: "Upper('1') == '1' && nilvar.Foo == '42'", Env: defaultEnv, ExpectedFailRuntime: true, ExpectedOutputs: []OpOutput{ @@ -264,13 +269,13 @@ func TestBaseDbg(t *testing.T) { {Code: "Upper(base_string)", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"hello world\""}, FuncResults: []string{"\"HELLO WORLD\""}, ConditionResult: (*bool)(nil), Finalized: true}, {Code: "Upper('/someotherurl?account-name=admin&account-status=1&ow=cmd') )", CodeDepth: 0, Func: true, FuncName: "Upper", Args: []string{"\"/someotherurl?account-name=admin&account...\""}, FuncResults: []string{"\"/SOMEOTHERURL?ACCOUNT-NAME=ADMIN&ACCOUNT...\""}, ConditionResult: (*bool)(nil), Finalized: true}, {Code: "contains Upper('/someotherurl?account-name=admin&account-status=1&ow=cmd') )", CodeDepth: 0, Args: []string{"\"HELLO WORLD\"", "\"/SOMEOTHERURL?ACCOUNT-NAME=ADMIN&ACCOUNT...\""}, Condition: true, ConditionContains: true, StrConditionResult: "[false]", ConditionResult: boolPtr(false), Finalized: true}, - {Code: "and", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: false}, {Code: "and", CodeDepth: 0, JumpIf: true, IfFalse: true, StrConditionResult: "false", ConditionResult: boolPtr(false), Finalized: true}, }, }, } logger := log.WithField("test", "exprhelpers") + for _, test := range tests { if test.LogLevel != 0 { log.SetLevel(test.LogLevel) @@ -307,10 +312,13 @@ func TestBaseDbg(t *testing.T) { t.Fatalf("test %s : unexpected compile error : %s", test.Name, err) } } - if test.Name == "nill deref" { - test.Env["nillvar"] = nil + + if test.Name == "nil deref" { + test.Env["nilvar"] = nil } + outdbg, ret, err := RunWithDebug(prog, test.Env, logger) + if test.ExpectedFailRuntime { if err == nil { t.Fatalf("test %s : expected runtime error", test.Name) @@ -320,25 +328,30 @@ func TestBaseDbg(t *testing.T) { t.Fatalf("test %s : unexpected runtime error : %s", test.Name, err) } } + log.SetLevel(log.DebugLevel) DisplayExprDebug(prog, outdbg, logger, ret) + if len(outdbg) != len(test.ExpectedOutputs) { t.Errorf("failed test %s", test.Name) t.Errorf("%#v", outdbg) - //out, _ := yaml.Marshal(outdbg) - //fmt.Printf("%s", string(out)) + // out, _ := yaml.Marshal(outdbg) + // fmt.Printf("%s", string(out)) t.Fatalf("test %s : expected %d outputs, got %d", test.Name, len(test.ExpectedOutputs), len(outdbg)) - } + for i, out := range outdbg { - if !reflect.DeepEqual(out, test.ExpectedOutputs[i]) { - spew.Config.DisableMethods = true - t.Errorf("failed test %s", test.Name) - t.Errorf("expected : %#v", test.ExpectedOutputs[i]) - t.Errorf("got : %#v", out) - t.Fatalf("%d/%d : mismatch", i, len(outdbg)) + if reflect.DeepEqual(out, test.ExpectedOutputs[i]) { + // DisplayExprDebug(prog, outdbg, logger, ret) + continue } - //DisplayExprDebug(prog, outdbg, logger, ret) + + spew.Config.DisableMethods = true + + t.Errorf("failed test %s", test.Name) + t.Errorf("expected : %#v", test.ExpectedOutputs[i]) + t.Errorf("got : %#v", out) + t.Fatalf("%d/%d : mismatch", i, len(outdbg)) } } } diff --git a/pkg/exprhelpers/expr_lib.go b/pkg/exprhelpers/expr_lib.go index db191b84a8d..b90c1986153 100644 --- a/pkg/exprhelpers/expr_lib.go +++ b/pkg/exprhelpers/expr_lib.go @@ -1,8 +1,11 @@ package exprhelpers import ( + "net" "time" + "github.com/oschwald/geoip2-golang" + "github.com/crowdsecurity/crowdsec/pkg/cticlient" ) @@ -231,6 +234,20 @@ var exprFuncs = []exprCustomFunc{ new(func(string) int), }, }, + { + name: "GetActiveDecisionsCount", + function: GetActiveDecisionsCount, + signature: []interface{}{ + new(func(string) int), + }, + }, + { + name: "GetActiveDecisionsTimeLeft", + function: GetActiveDecisionsTimeLeft, + signature: []interface{}{ + new(func(string) time.Duration), + }, + }, { name: "GetDecisionsSinceCount", function: GetDecisionsSinceCount, @@ -441,6 +458,41 @@ var exprFuncs = []exprCustomFunc{ new(func(float64, float64) bool), }, }, + { + name: "LibInjectionIsSQLI", + function: LibInjectionIsSQLI, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "LibInjectionIsXSS", + function: LibInjectionIsXSS, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "GeoIPEnrich", + function: GeoIPEnrich, + signature: []interface{}{ + new(func(string) *geoip2.City), + }, + }, + { + name: "GeoIPASNEnrich", + function: GeoIPASNEnrich, + signature: []interface{}{ + new(func(string) *geoip2.ASN), + }, + }, + { + name: "GeoIPRangeEnrich", + function: GeoIPRangeEnrich, + signature: []interface{}{ + new(func(string) *net.IPNet), + }, + }, } //go 1.20 "CutPrefix": strings.CutPrefix, diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 6b9cd15c73b..f2eb208ebfa 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -2,13 +2,12 @@ package exprhelpers import ( "context" - "fmt" + "errors" "os" "testing" "time" - "github.com/antonmedv/expr" - "github.com/pkg/errors" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,9 +21,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - TestFolder = "tests" -) +const TestFolder = "tests" func getDBClient(t *testing.T) *database.Client { t.Helper() @@ -32,7 +29,9 @@ func getDBClient(t *testing.T) *database.Client { dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{ + ctx := context.Background() + + testDBClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), @@ -78,21 +77,21 @@ func TestVisitor(t *testing.T) { name: "debug : can't compile", filter: "static_one.foo.toto == 'lol'", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, { name: "debug : can't compile #2", filter: "static_one.f!oo.to/to == 'lol'", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, { name: "debug : can't compile #3", filter: "", result: false, - err: fmt.Errorf("bad syntax"), + err: errors.New("bad syntax"), env: map[string]interface{}{"static_one": map[string]string{"foo": "bar"}}, }, } @@ -102,13 +101,13 @@ func TestVisitor(t *testing.T) { for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(test.env)...) if err != nil && test.err == nil { - log.Fatalf("compile: %s", err) + t.Fatalf("compile: %s", err) } if compiledFilter != nil { result, err := expr.Run(compiledFilter, test.env) if err != nil && test.err == nil { - log.Fatalf("run : %s", err) + t.Fatalf("run: %s", err) } if isOk := assert.Equal(t, test.result, result); !isOk { @@ -193,14 +192,16 @@ func TestDistanceHelper(t *testing.T) { "lat2": test.lat2, "lon2": test.lon2, } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) if err != nil { t.Fatalf("pattern:%s val:%s NOK %s", test.lat1, test.lon1, err) } + ret, err := expr.Run(vm, env) if test.valid { require.NoError(t, err) - assert.Equal(t, test.dist, ret) + assert.InDelta(t, test.dist, ret, 0.000001) } else { require.Error(t, err) } @@ -216,7 +217,7 @@ func TestRegexpCacheBehavior(t *testing.T) { err = FileInit(TestFolder, filename, "regex") require.NoError(t, err) - //cache with no TTL + // cache with no TTL err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(1)}) require.NoError(t, err) @@ -228,7 +229,7 @@ func TestRegexpCacheBehavior(t *testing.T) { assert.True(t, ret.(bool)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(false)) - //cache with TTL + // cache with TTL ttl := 500 * time.Millisecond err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: ptr.Of(2), TTL: &ttl}) require.NoError(t, err) @@ -243,12 +244,12 @@ func TestRegexpCacheBehavior(t *testing.T) { func TestRegexpInFile(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -286,23 +287,23 @@ func TestRegexpInFile(t *testing.T) { for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) if err != nil { - log.Fatal(err) + t.Fatal(err) } result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { - log.Fatal(err) + t.Fatal(err) } if isOk := assert.Equal(t, test.result, result); !isOk { - t.Fatalf("test '%s' : NOK", test.name) + t.Fatalf("test '%s': NOK", test.name) } } } func TestFileInit(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -340,7 +341,7 @@ func TestFileInit(t *testing.T) { for _, test := range tests { err := FileInit(TestFolder, test.filename, test.types) if err != nil { - log.Fatal(err) + t.Fatal(err) } switch test.types { @@ -376,12 +377,12 @@ func TestFileInit(t *testing.T) { func TestFile(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data.txt", "string") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -419,12 +420,12 @@ func TestFile(t *testing.T) { for _, test := range tests { compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) if err != nil { - log.Fatal(err) + t.Fatal(err) } result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { - log.Fatal(err) + t.Fatal(err) } if isOk := assert.Equal(t, test.result, result); !isOk { @@ -592,7 +593,7 @@ func TestAtof(t *testing.T) { require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) - require.Equal(t, test.result, output) + require.InDelta(t, test.result, output, 0.000001) } } @@ -935,7 +936,7 @@ func TestGetDecisionsCount(t *testing.T) { SaveX(context.Background()) if decision == nil { - require.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } err = Init(dbClient) @@ -995,6 +996,7 @@ func TestGetDecisionsCount(t *testing.T) { log.Printf("test '%s' : OK", test.name) } } + func TestGetDecisionsSinceCount(t *testing.T) { existingIP := "1.2.3.4" unknownIP := "1.2.3.5" @@ -1020,7 +1022,7 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetOrigin("CAPI"). SaveX(context.Background()) if decision == nil { - require.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } decision2 := dbClient.Ent.Decision.Create(). @@ -1039,7 +1041,7 @@ func TestGetDecisionsSinceCount(t *testing.T) { SaveX(context.Background()) if decision2 == nil { - require.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.New("Failed to create sample decision")) } err = Init(dbClient) @@ -1118,6 +1120,268 @@ func TestGetDecisionsSinceCount(t *testing.T) { } } +func TestGetActiveDecisionsCount(t *testing.T) { + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if decision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + expiredDecision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(-time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if expiredDecision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + err = Init(dbClient) + require.NoError(t, err) + + tests := []struct { + name string + env map[string]interface{} + code string + result string + err string + }{ + { + name: "GetActiveDecisionsCount() test: existing IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "Sprintf('%d', GetActiveDecisionsCount(Alert.GetValue()))", + result: "1", + err: "", + }, + { + name: "GetActiveDecisionsCount() test: unknown IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "Sprintf('%d', GetActiveDecisionsCount(Alert.GetValue()))", + result: "0", + err: "", + }, + } + + for _, test := range tests { + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) + require.NoError(t, err) + output, err := expr.Run(program, test.env) + require.NoError(t, err) + require.Equal(t, test.result, output) + log.Printf("test '%s' : OK", test.name) + } +} + +func TestGetActiveDecisionsTimeLeft(t *testing.T) { + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if decision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + longerDecision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().UTC().Add(2 * time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if longerDecision == nil { + require.Error(t, errors.New("Failed to create sample decision")) + } + + err = Init(dbClient) + require.NoError(t, err) + + tests := []struct { + name string + env map[string]interface{} + code string + min float64 + max float64 + err string + }{ + { + name: "GetActiveDecisionsTimeLeft() test: existing IP time left", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue())", + min: 7195, // 5 seconds margin to make sure the test doesn't fail randomly in the CI + max: 7200, + err: "", + }, + { + name: "GetActiveDecisionsTimeLeft() test: unknown IP time left", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue())", + min: 0, + max: 0, + err: "", + }, + { + name: "GetActiveDecisionsTimeLeft() test: existing IP and call time.Duration method", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue()).Hours()", + min: 2, + max: 2, + }, + { + name: "GetActiveDecisionsTimeLeft() test: unknown IP and call time.Duration method", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + }, + code: "GetActiveDecisionsTimeLeft(Alert.GetValue()).Hours()", + min: 0, + max: 0, + }, + } + + delta := 0.001 + + for _, test := range tests { + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) + require.NoError(t, err) + output, err := expr.Run(program, test.env) + require.NoError(t, err) + + switch o := output.(type) { + case time.Duration: + require.LessOrEqual(t, int(o.Seconds()), int(test.max)) + require.GreaterOrEqual(t, int(o.Seconds()), int(test.min)) + case float64: + require.LessOrEqual(t, o, test.max+delta) + require.GreaterOrEqual(t, o, test.min-delta) + default: + t.Fatalf("GetActiveDecisionsTimeLeft() should return a time.Duration or a float64") + } + } +} + func TestParseUnixTime(t *testing.T) { tests := []struct { name string @@ -1128,12 +1392,12 @@ func TestParseUnixTime(t *testing.T) { { name: "ParseUnix() test: valid value with milli", value: "1672239773.3590894", - expected: time.Date(2022, 12, 28, 15, 02, 53, 0, time.UTC), + expected: time.Date(2022, 12, 28, 15, 2, 53, 0, time.UTC), }, { name: "ParseUnix() test: valid value without milli", value: "1672239773", - expected: time.Date(2022, 12, 28, 15, 02, 53, 0, time.UTC), + expected: time.Date(2022, 12, 28, 15, 2, 53, 0, time.UTC), }, { name: "ParseUnix() test: invalid input", @@ -1150,13 +1414,14 @@ func TestParseUnixTime(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { output, err := ParseUnixTime(tc.value) cstest.RequireErrorContains(t, err, tc.expectedErr) + if tc.expectedErr != "" { return } + require.WithinDuration(t, tc.expected, output.(time.Time), time.Second) }) } @@ -1164,7 +1429,7 @@ func TestParseUnixTime(t *testing.T) { func TestIsIp(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -1252,13 +1517,13 @@ func TestIsIp(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { require.Error(t, err) return } + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) require.NoError(t, err) @@ -1304,7 +1569,6 @@ func TestToString(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) require.NoError(t, err) @@ -1351,19 +1615,21 @@ func TestB64Decode(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { require.Error(t, err) return } + require.NoError(t, err) + output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) if tc.expectedRuntimeErr { require.Error(t, err) return } + require.NoError(t, err) require.Equal(t, tc.expected, output) }) @@ -1421,7 +1687,6 @@ func TestParseKv(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { outMap := make(map[string]interface{}) env := map[string]interface{}{ diff --git a/pkg/exprhelpers/geoip.go b/pkg/exprhelpers/geoip.go new file mode 100644 index 00000000000..fb0c344d884 --- /dev/null +++ b/pkg/exprhelpers/geoip.go @@ -0,0 +1,63 @@ +package exprhelpers + +import ( + "net" +) + +func GeoIPEnrich(params ...any) (any, error) { + if geoIPCityReader == nil { + return nil, nil + } + + ip := params[0].(string) + + parsedIP := net.ParseIP(ip) + + city, err := geoIPCityReader.City(parsedIP) + + if err != nil { + return nil, err + } + + return city, nil +} + +func GeoIPASNEnrich(params ...any) (any, error) { + if geoIPASNReader == nil { + return nil, nil + } + + ip := params[0].(string) + + parsedIP := net.ParseIP(ip) + asn, err := geoIPASNReader.ASN(parsedIP) + + if err != nil { + return nil, err + } + + return asn, nil +} + +func GeoIPRangeEnrich(params ...any) (any, error) { + if geoIPRangeReader == nil { + return nil, nil + } + + ip := params[0].(string) + + var dummy interface{} + + parsedIP := net.ParseIP(ip) + rangeIP, ok, err := geoIPRangeReader.LookupNetwork(parsedIP, &dummy) + + if err != nil { + return nil, err + } + + if !ok { + return nil, nil + } + + return rangeIP, nil +} diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 79a621c7d35..9bc991a8f2d 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,7 +2,9 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" + "errors" "fmt" "math" "net" @@ -15,11 +17,13 @@ import ( "strings" "time" - "github.com/antonmedv/expr" "github.com/bluele/gcache" "github.com/c-robinson/iplib" "github.com/cespare/xxhash/v2" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/oschwald/geoip2-golang" + "github.com/oschwald/maxminddb-golang" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "github.com/umahmood/haversine" @@ -33,9 +37,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var dataFile map[string][]string -var dataFileRegex map[string][]*regexp.Regexp -var dataFileRe2 map[string][]*re2.Regexp +var ( + dataFile map[string][]string + dataFileRegex map[string][]*regexp.Regexp + dataFileRe2 map[string][]*re2.Regexp +) // This is used to (optionally) cache regexp results for RegexpInFile operations var dataFileRegexCache map[string]gcache.Cache = make(map[string]gcache.Cache) @@ -55,6 +61,12 @@ var exprFunctionOptions []expr.Option var keyValuePattern = regexp.MustCompile(`(?P[^=\s]+)=(?:"(?P[^"\\]*(?:\\.[^"\\]*)*)"|(?P[^=\s]+)|\s*)`) +var ( + geoIPCityReader *geoip2.Reader + geoIPASNReader *geoip2.Reader + geoIPRangeReader *maxminddb.Reader +) + func GetExprOptions(ctx map[string]interface{}) []expr.Option { if len(exprFunctionOptions) == 0 { exprFunctionOptions = []expr.Option{} @@ -66,32 +78,71 @@ func GetExprOptions(ctx map[string]interface{}) []expr.Option { )) } } + ret := []expr.Option{} ret = append(ret, exprFunctionOptions...) ret = append(ret, expr.Env(ctx)) + return ret } +func GeoIPInit(datadir string) error { + var err error + + geoIPCityReader, err = geoip2.Open(filepath.Join(datadir, "GeoLite2-City.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-City.mmdb : %s", err) + return err + } + + geoIPASNReader, err = geoip2.Open(filepath.Join(datadir, "GeoLite2-ASN.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-ASN.mmdb : %s", err) + return err + } + + geoIPRangeReader, err = maxminddb.Open(filepath.Join(datadir, "GeoLite2-ASN.mmdb")) + if err != nil { + log.Errorf("unable to open GeoLite2-ASN.mmdb : %s", err) + return err + } + + return nil +} + +func GeoIPClose() { + if geoIPCityReader != nil { + geoIPCityReader.Close() + } + + if geoIPASNReader != nil { + geoIPASNReader.Close() + } + + if geoIPRangeReader != nil { + geoIPRangeReader.Close() + } +} + func Init(databaseClient *database.Client) error { dataFile = make(map[string][]string) dataFileRegex = make(map[string][]*regexp.Regexp) dataFileRe2 = make(map[string][]*re2.Regexp) dbClient = databaseClient - + XMLCacheInit() return nil } func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { - - //cache is explicitly disabled + // cache is explicitly disabled if CacheCfg.Cache != nil && !*CacheCfg.Cache { return nil } - //cache is implicitly disabled if no cache config is provided + // cache is implicitly disabled if no cache config is provided if CacheCfg.Strategy == nil && CacheCfg.TTL == nil && CacheCfg.Size == nil { return nil } - //cache is enabled + // cache is enabled if CacheCfg.Size == nil { CacheCfg.Size = ptr.Of(50) @@ -102,6 +153,7 @@ func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { if CacheCfg.Strategy == nil { CacheCfg.Strategy = ptr.Of("LRU") } + switch *CacheCfg.Strategy { case "LRU": gc = gc.LRU() @@ -116,14 +168,17 @@ func RegexpCacheInit(filename string, CacheCfg types.DataSource) error { if CacheCfg.TTL != nil { gc.Expiration(*CacheCfg.TTL) } + cache := gc.Build() dataFileRegexCache[filename] = cache + return nil } // UpdateCacheMetrics is called directly by the prom handler func UpdateRegexpCacheMetrics() { RegexpCacheMetrics.Reset() + for name := range dataFileRegexCache { RegexpCacheMetrics.With(prometheus.Labels{"name": name}).Set(float64(dataFileRegexCache[name].Len(true))) } @@ -131,10 +186,12 @@ func UpdateRegexpCacheMetrics() { func FileInit(fileFolder string, filename string, fileType string) error { log.Debugf("init (folder:%s) (file:%s) (type:%s)", fileFolder, filename, fileType) + if fileType == "" { log.Debugf("ignored file %s%s because no type specified", fileFolder, filename) return nil } + ok, err := existsInFileMaps(filename, fileType) if ok { log.Debugf("ignored file %s%s because already loaded", fileFolder, filename) @@ -145,6 +202,7 @@ func FileInit(fileFolder string, filename string, fileType string) error { } filepath := filepath.Join(fileFolder, filename) + file, err := os.Open(filepath) if err != nil { return err @@ -156,31 +214,29 @@ func FileInit(fileFolder string, filename string, fileType string) error { if strings.HasPrefix(scanner.Text(), "#") { // allow comments continue } - if len(scanner.Text()) == 0 { //skip empty lines + if scanner.Text() == "" { //skip empty lines continue } + switch fileType { case "regex", "regexp": if fflag.Re2RegexpInfileSupport.IsEnabled() { dataFileRe2[filename] = append(dataFileRe2[filename], re2.MustCompile(scanner.Text())) continue } + dataFileRegex[filename] = append(dataFileRegex[filename], regexp.MustCompile(scanner.Text())) case "string": dataFile[filename] = append(dataFile[filename], scanner.Text()) } } - if err := scanner.Err(); err != nil { - return err - } - return nil + return scanner.Err() } // Expr helpers func Distinct(params ...any) (any, error) { - if rt := reflect.TypeOf(params[0]).Kind(); rt != reflect.Slice && rt != reflect.Array { return nil, nil } @@ -189,8 +245,8 @@ func Distinct(params ...any) (any, error) { return []interface{}{}, nil } - var exists map[any]bool = make(map[any]bool) - var ret []interface{} = make([]interface{}, 0) + exists := make(map[any]bool) + ret := make([]interface{}, 0) for _, val := range array { if _, ok := exists[val]; !ok { @@ -199,7 +255,6 @@ func Distinct(params ...any) (any, error) { } } return ret, nil - } func FlattenDistinct(params ...any) (any, error) { @@ -216,7 +271,7 @@ func flatten(args []interface{}, v reflect.Value) []interface{} { } if v.Kind() == reflect.Array || v.Kind() == reflect.Slice { - for i := 0; i < v.Len(); i++ { + for i := range v.Len() { args = flatten(args, v.Index(i)) } } else { @@ -225,6 +280,7 @@ func flatten(args []interface{}, v reflect.Value) []interface{} { return args } + func existsInFileMaps(filename string, ftype string) (bool, error) { ok := false var err error @@ -537,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + + ctx := context.TODO() + + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -550,7 +609,7 @@ func GetDecisionsSinceCount(params ...any) (any, error) { value := params[0].(string) since := params[1].(string) if dbClient == nil { - log.Error("No database config to call GetDecisionsCount()") + log.Error("No database config to call GetDecisionsSinceCount()") return 0, nil } sinceDuration, err := time.ParseDuration(since) @@ -558,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) return 0, nil } + + ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -567,6 +629,36 @@ func GetDecisionsSinceCount(params ...any) (any, error) { return count, nil } +func GetActiveDecisionsCount(params ...any) (any, error) { + value := params[0].(string) + if dbClient == nil { + log.Error("No database config to call GetActiveDecisionsCount()") + return 0, nil + } + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) + if err != nil { + log.Errorf("Failed to get active decisions count from value '%s'", value) + return 0, err + } + return count, nil +} + +func GetActiveDecisionsTimeLeft(params ...any) (any, error) { + value := params[0].(string) + if dbClient == nil { + log.Error("No database config to call GetActiveDecisionsTimeLeft()") + return 0, nil + } + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) + if err != nil { + log.Errorf("Failed to get active decisions time left from value '%s'", value) + return 0, err + } + return timeLeft, nil +} + // func LookupHost(value string) []string { func LookupHost(params ...any) (any, error) { value := params[0].(string) @@ -682,7 +774,6 @@ func B64Decode(params ...any) (any, error) { } func ParseKV(params ...any) (any, error) { - blob := params[0].(string) target := params[1].(map[string]interface{}) prefix := params[2].(string) @@ -690,7 +781,7 @@ func ParseKV(params ...any) (any, error) { matches := keyValuePattern.FindAllStringSubmatch(blob, -1) if matches == nil { log.Errorf("could not find any key/value pair in line") - return nil, fmt.Errorf("invalid input format") + return nil, errors.New("invalid input format") } if _, ok := target[prefix]; !ok { target[prefix] = make(map[string]string) @@ -698,7 +789,7 @@ func ParseKV(params ...any) (any, error) { _, ok := target[prefix].(map[string]string) if !ok { log.Errorf("ParseKV: target is not a map[string]string") - return nil, fmt.Errorf("target is not a map[string]string") + return nil, errors.New("target is not a map[string]string") } } for _, match := range matches { diff --git a/pkg/exprhelpers/jsonextract.go b/pkg/exprhelpers/jsonextract.go index 6edb34e36e6..64ed97873d6 100644 --- a/pkg/exprhelpers/jsonextract.go +++ b/pkg/exprhelpers/jsonextract.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/buger/jsonparser" - log "github.com/sirupsen/logrus" ) @@ -15,11 +14,11 @@ import ( func JsonExtractLib(params ...any) (any, error) { jsblob := params[0].(string) target := params[1].([]string) + value, dataType, _, err := jsonparser.Get( jsonparser.StringToBytes(jsblob), target..., ) - if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { log.Debugf("%+v doesn't exist", target) @@ -93,7 +92,6 @@ func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]by jsonparser.StringToBytes(jsblob), fullpath..., ) - if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { log.Debugf("Key %+v doesn't exist", target) @@ -115,8 +113,8 @@ func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]by func JsonExtractSlice(params ...any) (any, error) { jsblob := params[0].(string) target := params[1].(string) - value, err := jsonExtractType(jsblob, target, jsonparser.Array) + value, err := jsonExtractType(jsblob, target, jsonparser.Array) if err != nil { log.Errorf("JsonExtractSlice : %s", err) return []interface{}(nil), nil @@ -136,8 +134,8 @@ func JsonExtractSlice(params ...any) (any, error) { func JsonExtractObject(params ...any) (any, error) { jsblob := params[0].(string) target := params[1].(string) - value, err := jsonExtractType(jsblob, target, jsonparser.Object) + value, err := jsonExtractType(jsblob, target, jsonparser.Object) if err != nil { log.Errorf("JsonExtractObject: %s", err) return map[string]interface{}(nil), nil diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 1bd45aa2d6a..5845c3ae66b 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -3,21 +3,19 @@ package exprhelpers import ( "testing" - log "github.com/sirupsen/logrus" - - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestJsonExtract(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -67,12 +65,12 @@ func TestJsonExtract(t *testing.T) { func TestJsonExtractUnescape(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -115,12 +113,12 @@ func TestJsonExtractUnescape(t *testing.T) { func TestJsonExtractSlice(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -161,7 +159,6 @@ func TestJsonExtractSlice(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { env := map[string]interface{}{ "blob": test.jsonBlob, @@ -178,12 +175,12 @@ func TestJsonExtractSlice(t *testing.T) { func TestJsonExtractObject(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } err := FileInit(TestFolder, "test_data_re.txt", "regex") if err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -217,7 +214,6 @@ func TestJsonExtractObject(t *testing.T) { } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { env := map[string]interface{}{ "blob": test.jsonBlob, diff --git a/pkg/exprhelpers/libinjection.go b/pkg/exprhelpers/libinjection.go new file mode 100644 index 00000000000..e9f33e4f459 --- /dev/null +++ b/pkg/exprhelpers/libinjection.go @@ -0,0 +1,17 @@ +package exprhelpers + +import "github.com/corazawaf/libinjection-go" + +func LibInjectionIsSQLI(params ...any) (any, error) { + str := params[0].(string) + + ret, _ := libinjection.IsSQLi(str) + return ret, nil +} + +func LibInjectionIsXSS(params ...any) (any, error) { + str := params[0].(string) + + ret := libinjection.IsXSS(str) + return ret, nil +} diff --git a/pkg/exprhelpers/libinjection_test.go b/pkg/exprhelpers/libinjection_test.go new file mode 100644 index 00000000000..7b4ab825db9 --- /dev/null +++ b/pkg/exprhelpers/libinjection_test.go @@ -0,0 +1,60 @@ +package exprhelpers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLibinjectionHelpers(t *testing.T) { + tests := []struct { + name string + function func(params ...any) (any, error) + params []any + expectResult any + }{ + { + name: "LibInjectionIsSQLI", + function: LibInjectionIsSQLI, + params: []any{"?__f__73=73&&__f__75=75&delivery=1&max=24.9&min=15.9&n=12&o=2&p=(select(0)from(select(sleep(15)))v)/*'%2B(select(0)from(select(sleep(15)))v)%2B'\x22%2B(select(0)from(select(sleep(15)))v)%2B\x22*/&rating=4"}, + expectResult: true, + }, + { + name: "LibInjectionIsSQLI - no match", + function: LibInjectionIsSQLI, + params: []any{"?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsSQLI - no match 2", + function: LibInjectionIsSQLI, + params: []any{"https://foo.com/asdkfj?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsXSS", + function: LibInjectionIsXSS, + params: []any{""}, + expectResult: true, + }, + { + name: "LibInjectionIsXSS - no match", + function: LibInjectionIsXSS, + params: []any{"?bla=42&foo=bar"}, + expectResult: false, + }, + { + name: "LibInjectionIsXSS - no match 2", + function: LibInjectionIsXSS, + params: []any{"https://foo.com/asdkfj?bla=42&foo[]=bar&foo"}, + expectResult: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, _ := test.function(test.params...) + assert.Equal(t, test.expectResult, result) + }) + } +} diff --git a/pkg/exprhelpers/xml.go b/pkg/exprhelpers/xml.go index 75758e18316..0b550bdb641 100644 --- a/pkg/exprhelpers/xml.go +++ b/pkg/exprhelpers/xml.go @@ -1,43 +1,103 @@ package exprhelpers import ( + "errors" + "sync" + "time" + "github.com/beevik/etree" + "github.com/bluele/gcache" + "github.com/cespare/xxhash/v2" log "github.com/sirupsen/logrus" ) -var pathCache = make(map[string]etree.Path) +var ( + pathCache = make(map[string]etree.Path) + rwMutex = sync.RWMutex{} + xmlDocumentCache gcache.Cache +) + +func compileOrGetPath(path string) (etree.Path, error) { + rwMutex.RLock() + compiledPath, ok := pathCache[path] + rwMutex.RUnlock() + + if !ok { + var err error + compiledPath, err = etree.CompilePath(path) + if err != nil { + return etree.Path{}, err + } + + rwMutex.Lock() + pathCache[path] = compiledPath + rwMutex.Unlock() + } + + return compiledPath, nil +} + +func getXMLDocumentFromCache(xmlString string) (*etree.Document, error) { + cacheKey := xxhash.Sum64String(xmlString) + cacheObj, err := xmlDocumentCache.Get(cacheKey) + + if err != nil && !errors.Is(err, gcache.KeyNotFoundError) { + return nil, err + } + + doc, ok := cacheObj.(*etree.Document) + if !ok || cacheObj == nil { + doc = etree.NewDocument() + if err := doc.ReadFromString(xmlString); err != nil { + return nil, err + } + if err := xmlDocumentCache.Set(cacheKey, doc); err != nil { + log.Warnf("Could not set XML document in cache: %s", err) + } + } + + return doc, nil +} + +func XMLCacheInit() { + gc := gcache.New(50) + // Short cache expiration because we each line we read is different, but we can call multiple times XML helpers on each of them + gc.Expiration(5 * time.Second) + gc = gc.LRU() + + xmlDocumentCache = gc.Build() +} // func XMLGetAttributeValue(xmlString string, path string, attributeName string) string { func XMLGetAttributeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) attributeName := params[2].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - } - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + attr := elem.SelectAttr(attributeName) if attr == nil { log.Debugf("Could not find attribute %s", attributeName) return "", nil } + return attr.Value, nil } @@ -45,26 +105,24 @@ func XMLGetAttributeValue(params ...any) (any, error) { func XMLGetNodeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - } - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + return elem.Text(), nil } diff --git a/pkg/exprhelpers/xml_test.go b/pkg/exprhelpers/xml_test.go index 516387f764b..42823884025 100644 --- a/pkg/exprhelpers/xml_test.go +++ b/pkg/exprhelpers/xml_test.go @@ -9,7 +9,7 @@ import ( func TestXMLGetAttributeValue(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -58,17 +58,19 @@ func TestXMLGetAttributeValue(t *testing.T) { for _, test := range tests { result, _ := XMLGetAttributeValue(test.xmlString, test.path, test.attribute) + isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) } + log.Printf("test '%s' : OK", test.name) } - } + func TestXMLGetNodeValue(t *testing.T) { if err := Init(nil); err != nil { - log.Fatal(err) + t.Fatal(err) } tests := []struct { @@ -105,11 +107,12 @@ func TestXMLGetNodeValue(t *testing.T) { for _, test := range tests { result, _ := XMLGetNodeValue(test.xmlString, test.path) + isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) } + log.Printf("test '%s' : OK", test.name) } - } diff --git a/pkg/fflag/features.go b/pkg/fflag/features.go index 3a106984a66..c8a3d7755ea 100644 --- a/pkg/fflag/features.go +++ b/pkg/fflag/features.go @@ -97,7 +97,7 @@ type FeatureRegister struct { features map[string]*Feature } -var featureNameRexp = regexp.MustCompile(`^[a-z0-9_\.]+$`) +var featureNameRexp = regexp.MustCompile(`^[a-z0-9_.]+$`) func validateFeatureName(featureName string) error { if featureName == "" { diff --git a/pkg/fflag/features_test.go b/pkg/fflag/features_test.go index 57745b3c38c..481e86573e8 100644 --- a/pkg/fflag/features_test.go +++ b/pkg/fflag/features_test.go @@ -50,8 +50,6 @@ func TestRegisterFeature(t *testing.T) { } for _, tc := range tests { - tc := tc - t.Run("", func(t *testing.T) { fr := fflag.FeatureRegister{EnvPrefix: "FFLAG_TEST_"} err := fr.RegisterFeature(&tc.feature) @@ -112,7 +110,6 @@ func TestGetFeature(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { _, err := fr.GetFeature(tc.feature) cstest.RequireErrorMessage(t, err, tc.expectedErr) @@ -145,7 +142,6 @@ func TestIsEnabled(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { feat, err := fr.GetFeature(tc.feature) require.NoError(t, err) @@ -204,7 +200,6 @@ func TestFeatureSet(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { feat, err := fr.GetFeature(tc.feature) cstest.RequireErrorMessage(t, err, tc.expectedGetErr) @@ -284,7 +279,6 @@ func TestSetFromEnv(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := logtest.NewNullLogger() logger.SetLevel(logrus.DebugLevel) @@ -344,7 +338,6 @@ func TestSetFromYaml(t *testing.T) { fr := setUp(t) for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { logger, hook := logtest.NewNullLogger() logger.SetLevel(logrus.DebugLevel) diff --git a/pkg/hubtest/appsecrule.go b/pkg/hubtest/appsecrule.go index 9b70e1441ac..1c4416c2e9b 100644 --- a/pkg/hubtest/appsecrule.go +++ b/pkg/hubtest/appsecrule.go @@ -11,75 +11,77 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installAppsecRuleItem(hubAppsecRule *cwhub.Item) error { - appsecRuleSource, err := filepath.Abs(filepath.Join(t.HubPath, hubAppsecRule.RemotePath)) +func (t *HubTestItem) installAppsecRuleItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", appsecRuleSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - appsecRuleFilename := filepath.Base(appsecRuleSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/appsec-rules/author/appsec-rule - hubDirAppsecRuleDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubAppsecRule.RemotePath)) + hubDirAppsecRuleDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/appsec-rules/ - appsecRuleDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) - if err := os.MkdirAll(hubDirAppsecRuleDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirAppsecRuleDest, err) - } - - if err := os.MkdirAll(appsecRuleDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", appsecRuleDirDest, err) + if err := createDirs([]string{hubDirAppsecRuleDest, itemTypeDirDest}); err != nil { + return err } // runtime/hub/appsec-rules/crowdsecurity/rule.yaml - hubDirAppsecRulePath := filepath.Join(appsecRuleDirDest, appsecRuleFilename) - if err := Copy(appsecRuleSource, hubDirAppsecRulePath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", appsecRuleSource, hubDirAppsecRulePath, err) + hubDirAppsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) + if err := Copy(sourcePath, hubDirAppsecRulePath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirAppsecRulePath, err) } // runtime/appsec-rules/rule.yaml - appsecRulePath := filepath.Join(appsecRuleDirDest, appsecRuleFilename) + appsecRulePath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirAppsecRulePath, appsecRulePath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink appsec-rule '%s' to '%s': %s", hubDirAppsecRulePath, appsecRulePath, err) + return fmt.Errorf("unable to symlink appsec-rule '%s' to '%s': %w", hubDirAppsecRulePath, appsecRulePath, err) } } return nil } +func (t *HubTestItem) installAppsecRuleCustomFrom(appsecrule string, customPath string) (bool, error) { + // we check if its a custom appsec-rule + customAppsecRulePath := filepath.Join(customPath, appsecrule) + if _, err := os.Stat(customAppsecRulePath); os.IsNotExist(err) { + return false, nil + } + + customAppsecRulePathSplit := strings.Split(customAppsecRulePath, "/") + customAppsecRuleName := customAppsecRulePathSplit[len(customAppsecRulePathSplit)-1] + + itemTypeDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + customAppsecRuleDest := fmt.Sprintf("%s/appsec-rules/%s", t.RuntimePath, customAppsecRuleName) + if err := Copy(customAppsecRulePath, customAppsecRuleDest); err != nil { + return false, fmt.Errorf("unable to copy appsec-rule from '%s' to '%s': %w", customAppsecRulePath, customAppsecRuleDest, err) + } + + return true, nil +} + func (t *HubTestItem) installAppsecRuleCustom(appsecrule string) error { - customAppsecRuleExist := false for _, customPath := range t.CustomItemsLocation { - // we check if its a custom appsec-rule - customAppsecRulePath := filepath.Join(customPath, appsecrule) - if _, err := os.Stat(customAppsecRulePath); os.IsNotExist(err) { - continue + found, err := t.installAppsecRuleCustomFrom(appsecrule, customPath) + if err != nil { + return err } - customAppsecRulePathSplit := strings.Split(customAppsecRulePath, "/") - customAppsecRuleName := customAppsecRulePathSplit[len(customAppsecRulePathSplit)-1] - appsecRuleDirDest := fmt.Sprintf("%s/appsec-rules/", t.RuntimePath) - if err := os.MkdirAll(appsecRuleDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", appsecRuleDirDest, err) + if found { + return nil } - - // runtime/appsec-rules/ - customAppsecRuleDest := fmt.Sprintf("%s/appsec-rules/%s", t.RuntimePath, customAppsecRuleName) - // if path to postoverflow exist, copy it - if err := Copy(customAppsecRulePath, customAppsecRuleDest); err != nil { - continue - } - customAppsecRuleExist = true - break - } - if !customAppsecRuleExist { - return fmt.Errorf("couldn't find custom appsec-rule '%s' in the following location: %+v", appsecrule, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom appsec-rule '%s' in the following location: %+v", appsecrule, t.CustomItemsLocation) } func (t *HubTestItem) installAppsecRule(name string) error { diff --git a/pkg/hubtest/coverage.go b/pkg/hubtest/coverage.go index dc3d1d13ad2..e42c1e23455 100644 --- a/pkg/hubtest/coverage.go +++ b/pkg/hubtest/coverage.go @@ -2,27 +2,30 @@ package hubtest import ( "bufio" + "errors" "fmt" "os" "path/filepath" "strings" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/go-cs-lib/maptools" - log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" ) type Coverage struct { Name string TestsCount int - PresentIn map[string]bool //poorman's set + PresentIn map[string]bool // poorman's set } func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { if len(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES)) == 0 { - return nil, fmt.Errorf("no appsec rules in hub index") + return nil, errors.New("no appsec rules in hub index") } // populate from hub, iterate in alphabetical order @@ -40,31 +43,36 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { // parser the expressions a-la-oneagain appsecTestConfigs, err := filepath.Glob(".appsec-tests/*/config.yaml") if err != nil { - return nil, fmt.Errorf("while find appsec-tests config: %s", err) + return nil, fmt.Errorf("while find appsec-tests config: %w", err) } for _, appsecTestConfigPath := range appsecTestConfigs { configFileData := &HubTestItemConfig{} + yamlFile, err := os.ReadFile(appsecTestConfigPath) if err != nil { log.Printf("unable to open appsec test config file '%s': %s", appsecTestConfigPath, err) continue } + err = yaml.Unmarshal(yamlFile, configFileData) if err != nil { - return nil, fmt.Errorf("unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %v", err) } for _, appsecRulesFile := range configFileData.AppsecRules { appsecRuleData := &appsec_rule.CustomRule{} + yamlFile, err := os.ReadFile(appsecRulesFile) if err != nil { log.Printf("unable to open appsec rule '%s': %s", appsecRulesFile, err) } + err = yaml.Unmarshal(yamlFile, appsecRuleData) if err != nil { - return nil, fmt.Errorf("unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %v", err) } + appsecRuleName := appsecRuleData.Name for idx, cov := range coverage { @@ -81,7 +89,7 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { func (h *HubTest) GetParsersCoverage() ([]Coverage, error) { if len(h.HubIndex.GetItemMap(cwhub.PARSERS)) == 0 { - return nil, fmt.Errorf("no parsers in hub index") + return nil, errors.New("no parsers in hub index") } // populate from hub, iterate in alphabetical order @@ -99,13 +107,13 @@ func (h *HubTest) GetParsersCoverage() ([]Coverage, error) { // parser the expressions a-la-oneagain passerts, err := filepath.Glob(".tests/*/parser.assert") if err != nil { - return nil, fmt.Errorf("while find parser asserts : %s", err) + return nil, fmt.Errorf("while find parser asserts: %w", err) } for _, assert := range passerts { file, err := os.Open(assert) if err != nil { - return nil, fmt.Errorf("while reading %s : %s", assert, err) + return nil, fmt.Errorf("while reading %s: %w", assert, err) } scanner := bufio.NewScanner(file) @@ -167,7 +175,7 @@ func (h *HubTest) GetParsersCoverage() ([]Coverage, error) { func (h *HubTest) GetScenariosCoverage() ([]Coverage, error) { if len(h.HubIndex.GetItemMap(cwhub.SCENARIOS)) == 0 { - return nil, fmt.Errorf("no scenarios in hub index") + return nil, errors.New("no scenarios in hub index") } // populate from hub, iterate in alphabetical order @@ -185,13 +193,13 @@ func (h *HubTest) GetScenariosCoverage() ([]Coverage, error) { // parser the expressions a-la-oneagain passerts, err := filepath.Glob(".tests/*/scenario.assert") if err != nil { - return nil, fmt.Errorf("while find scenario asserts : %s", err) + return nil, fmt.Errorf("while find scenario asserts: %w", err) } for _, assert := range passerts { file, err := os.Open(assert) if err != nil { - return nil, fmt.Errorf("while reading %s : %s", assert, err) + return nil, fmt.Errorf("while reading %s: %w", assert, err) } scanner := bufio.NewScanner(file) diff --git a/pkg/hubtest/hubtest.go b/pkg/hubtest/hubtest.go index 6610652f78a..93f5abaa879 100644 --- a/pkg/hubtest/hubtest.go +++ b/pkg/hubtest/hubtest.go @@ -83,7 +83,7 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT } if isAppsecTest { - HubTestPath := filepath.Join(hubPath, "./.appsec-tests/") + HubTestPath := filepath.Join(hubPath, ".appsec-tests") hubIndexFile := filepath.Join(hubPath, ".index.json") local := &csconfig.LocalHubCfg{ @@ -93,9 +93,13 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT InstallDataDir: HubTestPath, } - hub, err := cwhub.NewHub(local, nil, false, nil) + hub, err := cwhub.NewHub(local, nil, nil) if err != nil { - return HubTest{}, fmt.Errorf("unable to load hub: %s", err) + return HubTest{}, err + } + + if err := hub.Load(); err != nil { + return HubTest{}, err } return HubTest{ @@ -115,7 +119,7 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT }, nil } - HubTestPath := filepath.Join(hubPath, "./.tests/") + HubTestPath := filepath.Join(hubPath, ".tests") hubIndexFile := filepath.Join(hubPath, ".index.json") @@ -126,9 +130,13 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT InstallDataDir: HubTestPath, } - hub, err := cwhub.NewHub(local, nil, false, nil) + hub, err := cwhub.NewHub(local, nil, nil) if err != nil { - return HubTest{}, fmt.Errorf("unable to load hub: %s", err) + return HubTest{}, err + } + + if err := hub.Load(); err != nil { + return HubTest{}, err } return HubTest{ diff --git a/pkg/hubtest/hubtest_item.go b/pkg/hubtest/hubtest_item.go index b8a042f071f..bc9c8955d0d 100644 --- a/pkg/hubtest/hubtest_item.go +++ b/pkg/hubtest/hubtest_item.go @@ -1,6 +1,7 @@ package hubtest import ( + "context" "errors" "fmt" "net/url" @@ -10,7 +11,7 @@ import ( "strings" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" @@ -28,7 +29,7 @@ type HubTestItemConfig struct { LogType string `yaml:"log_type,omitempty"` Labels map[string]string `yaml:"labels,omitempty"` IgnoreParsers bool `yaml:"ignore_parsers,omitempty"` // if we test a scenario, we don't want to assert on Parser - OverrideStatics []parser.ExtraField `yaml:"override_statics,omitempty"` //Allow to override statics. Executed before s00 + OverrideStatics []parser.ExtraField `yaml:"override_statics,omitempty"` // Allow to override statics. Executed before s00 } type HubTestItem struct { @@ -110,7 +111,7 @@ func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { err = yaml.Unmarshal(yamlFile, configFileData) if err != nil { - return nil, fmt.Errorf("unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %w", err) } parserAssertFilePath := filepath.Join(testPath, ParserAssertFileName) @@ -200,55 +201,52 @@ func (t *HubTestItem) InstallHub() error { b, err := yaml.Marshal(n) if err != nil { - return fmt.Errorf("unable to marshal overrides: %s", err) + return fmt.Errorf("unable to serialize overrides: %w", err) } tgtFilename := fmt.Sprintf("%s/parsers/s00-raw/00_overrides.yaml", t.RuntimePath) if err := os.WriteFile(tgtFilename, b, os.ModePerm); err != nil { - return fmt.Errorf("unable to write overrides to '%s': %s", tgtFilename, err) + return fmt.Errorf("unable to write overrides to '%s': %w", tgtFilename, err) } } // load installed hub - hub, err := cwhub.NewHub(t.RuntimeHubConfig, nil, false, nil) + hub, err := cwhub.NewHub(t.RuntimeHubConfig, nil, nil) if err != nil { - log.Fatal(err) + return err } - // install data for parsers if needed - ret := hub.GetItemMap(cwhub.PARSERS) - for parserName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", parserName, err) - } + if err := hub.Load(); err != nil { + return err + } - log.Debugf("parser '%s' installed successfully in runtime environment", parserName) + ctx := context.Background() + + // install data for parsers if needed + for _, item := range hub.GetInstalledByType(cwhub.PARSERS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("parser '%s' installed successfully in runtime environment", item.Name) } // install data for scenarios if needed - ret = hub.GetItemMap(cwhub.SCENARIOS) - for scenarioName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", scenarioName, err) - } - - log.Debugf("scenario '%s' installed successfully in runtime environment", scenarioName) + for _, item := range hub.GetInstalledByType(cwhub.SCENARIOS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("scenario '%s' installed successfully in runtime environment", item.Name) } // install data for postoverflows if needed - ret = hub.GetItemMap(cwhub.POSTOVERFLOWS) - for postoverflowName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", postoverflowName, err) - } - - log.Debugf("postoverflow '%s' installed successfully in runtime environment", postoverflowName) + for _, item := range hub.GetInstalledByType(cwhub.POSTOVERFLOWS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("postoverflow '%s' installed successfully in runtime environment", item.Name) } return nil @@ -267,10 +265,10 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { } if err := os.Chdir(testPath); err != nil { - return fmt.Errorf("can't 'cd' to '%s': %s", testPath, err) + return fmt.Errorf("can't 'cd' to '%s': %w", testPath, err) } - //machine add + // machine add cmdArgs := []string{"-c", t.RuntimeConfigFilePath, "machines", "add", "testMachine", "--force", "--auto"} cscliRegisterCmd := exec.Command(t.CscliPath, cmdArgs...) @@ -282,7 +280,7 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { } } - //hardcode bouncer key + // hardcode bouncer key cmdArgs = []string{"-c", t.RuntimeConfigFilePath, "bouncers", "add", "appsectests", "-k", TestBouncerApiKey} cscliBouncerCmd := exec.Command(t.CscliPath, cmdArgs...) @@ -294,13 +292,13 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { } } - //start crowdsec service + // start crowdsec service cmdArgs = []string{"-c", t.RuntimeConfigFilePath} crowdsecDaemon := exec.Command(t.CrowdSecPath, cmdArgs...) crowdsecDaemon.Start() - //wait for the appsec port to be available + // wait for the appsec port to be available if _, err := IsAlive(t.AppSecHost); err != nil { crowdsecLog, err2 := os.ReadFile(crowdsecLogFile) if err2 != nil { @@ -310,27 +308,28 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { log.Errorf("%s\n", string(crowdsecLog)) } - return fmt.Errorf("appsec is down: %s", err) + return fmt.Errorf("appsec is down: %w", err) } // check if the target is available nucleiTargetParsedURL, err := url.Parse(t.NucleiTargetHost) if err != nil { - return fmt.Errorf("unable to parse target '%s': %s", t.NucleiTargetHost, err) + return fmt.Errorf("unable to parse target '%s': %w", t.NucleiTargetHost, err) } nucleiTargetHost := nucleiTargetParsedURL.Host if _, err := IsAlive(nucleiTargetHost); err != nil { - return fmt.Errorf("target is down: %s", err) + return fmt.Errorf("target is down: %w", err) } nucleiConfig := NucleiConfig{ Path: "nuclei", OutputDir: t.RuntimePath, - CmdLineOptions: []string{"-ev", //allow variables from environment - "-nc", //no colors in output - "-dresp", //dump response - "-j", //json output + CmdLineOptions: []string{ + "-ev", // allow variables from environment + "-nc", // no colors in output + "-dresp", // dump response + "-j", // json output }, } @@ -341,6 +340,7 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { t.Success = true } else { log.Errorf("Appsec test %s failed: %s", t.Name, err) + crowdsecLog, err := os.ReadFile(crowdsecLogFile) if err != nil { log.Errorf("unable to read crowdsec log file '%s': %s", crowdsecLogFile, err) @@ -355,6 +355,7 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { t.Success = true } else { log.Errorf("Appsec test %s failed: %s", t.Name, err) + crowdsecLog, err := os.ReadFile(crowdsecLogFile) if err != nil { log.Errorf("unable to read crowdsec log file '%s': %s", crowdsecLogFile, err) @@ -370,39 +371,34 @@ func (t *HubTestItem) RunWithNucleiTemplate() error { return nil } +func createDirs(dirs []string) error { + for _, dir := range dirs { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return fmt.Errorf("unable to create directory '%s': %w", dir, err) + } + } + + return nil +} + func (t *HubTestItem) RunWithLogFile() error { testPath := filepath.Join(t.HubTestPath, t.Name) if _, err := os.Stat(testPath); os.IsNotExist(err) { return fmt.Errorf("test '%s' doesn't exist in '%s', exiting", t.Name, t.HubTestPath) } - currentDir, err := os.Getwd() //xx + currentDir, err := os.Getwd() // xx if err != nil { return fmt.Errorf("can't get current directory: %+v", err) } - // create runtime folder - if err = os.MkdirAll(t.RuntimePath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimePath, err) - } - - // create runtime data folder - if err = os.MkdirAll(t.RuntimeDataPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeDataPath, err) - } - - // create runtime hub folder - if err = os.MkdirAll(t.RuntimeHubPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeHubPath, err) + // create runtime, data, hub folders + if err = createDirs([]string{t.RuntimePath, t.RuntimeDataPath, t.RuntimeHubPath, t.ResultsPath}); err != nil { + return err } if err = Copy(t.HubIndexFile, filepath.Join(t.RuntimeHubPath, ".index.json")); err != nil { - return fmt.Errorf("unable to copy .index.json file in '%s': %s", filepath.Join(t.RuntimeHubPath, ".index.json"), err) - } - - // create results folder - if err = os.MkdirAll(t.ResultsPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.ResultsPath, err) + return fmt.Errorf("unable to copy .index.json file in '%s': %w", filepath.Join(t.RuntimeHubPath, ".index.json"), err) } // copy template config file to runtime folder @@ -424,12 +420,12 @@ func (t *HubTestItem) RunWithLogFile() error { // copy template patterns folder to runtime folder if err = CopyDir(crowdsecPatternsFolder, t.RuntimePatternsPath); err != nil { - return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %s", crowdsecPatternsFolder, t.RuntimePatternsPath, err) + return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %w", crowdsecPatternsFolder, t.RuntimePatternsPath, err) } // install the hub in the runtime folder if err = t.InstallHub(); err != nil { - return fmt.Errorf("unable to install hub in '%s': %s", t.RuntimeHubPath, err) + return fmt.Errorf("unable to install hub in '%s': %w", t.RuntimeHubPath, err) } logFile := t.Config.LogFile @@ -437,12 +433,12 @@ func (t *HubTestItem) RunWithLogFile() error { dsn := fmt.Sprintf("file://%s", logFile) if err = os.Chdir(testPath); err != nil { - return fmt.Errorf("can't 'cd' to '%s': %s", testPath, err) + return fmt.Errorf("can't 'cd' to '%s': %w", testPath, err) } logFileStat, err := os.Stat(logFile) if err != nil { - return fmt.Errorf("unable to stat log file '%s': %s", logFile, err) + return fmt.Errorf("unable to stat log file '%s': %w", logFile, err) } if logFileStat.Size() == 0 { @@ -481,7 +477,7 @@ func (t *HubTestItem) RunWithLogFile() error { } if err := os.Chdir(currentDir); err != nil { - return fmt.Errorf("can't 'cd' to '%s': %s", currentDir, err) + return fmt.Errorf("can't 'cd' to '%s': %w", currentDir, err) } // assert parsers @@ -498,20 +494,20 @@ func (t *HubTestItem) RunWithLogFile() error { assertFileStat, err := os.Stat(t.ParserAssert.File) if err != nil { - return fmt.Errorf("error while stats '%s': %s", t.ParserAssert.File, err) + return fmt.Errorf("error while stats '%s': %w", t.ParserAssert.File, err) } if assertFileStat.Size() == 0 { assertData, err := t.ParserAssert.AutoGenFromFile(t.ParserResultFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } t.ParserAssert.AutoGenAssertData = assertData t.ParserAssert.AutoGenAssert = true } else { if err := t.ParserAssert.AssertFile(t.ParserResultFile); err != nil { - return fmt.Errorf("unable to run assertion on file '%s': %s", t.ParserResultFile, err) + return fmt.Errorf("unable to run assertion on file '%s': %w", t.ParserResultFile, err) } } } @@ -540,20 +536,20 @@ func (t *HubTestItem) RunWithLogFile() error { assertFileStat, err := os.Stat(t.ScenarioAssert.File) if err != nil { - return fmt.Errorf("error while stats '%s': %s", t.ScenarioAssert.File, err) + return fmt.Errorf("error while stats '%s': %w", t.ScenarioAssert.File, err) } if assertFileStat.Size() == 0 { assertData, err := t.ScenarioAssert.AutoGenFromFile(t.ScenarioResultFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } t.ScenarioAssert.AutoGenAssertData = assertData t.ScenarioAssert.AutoGenAssert = true } else { if err := t.ScenarioAssert.AssertFile(t.ScenarioResultFile); err != nil { - return fmt.Errorf("unable to run assertion on file '%s': %s", t.ScenarioResultFile, err) + return fmt.Errorf("unable to run assertion on file '%s': %w", t.ScenarioResultFile, err) } } } @@ -575,28 +571,13 @@ func (t *HubTestItem) Run() error { t.Success = false t.ErrorsList = make([]string, 0) - // create runtime folder - if err = os.MkdirAll(t.RuntimePath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimePath, err) - } - - // create runtime data folder - if err = os.MkdirAll(t.RuntimeDataPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeDataPath, err) - } - - // create runtime hub folder - if err = os.MkdirAll(t.RuntimeHubPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimeHubPath, err) + // create runtime, data, hub, result folders + if err = createDirs([]string{t.RuntimePath, t.RuntimeDataPath, t.RuntimeHubPath, t.ResultsPath}); err != nil { + return err } if err = Copy(t.HubIndexFile, filepath.Join(t.RuntimeHubPath, ".index.json")); err != nil { - return fmt.Errorf("unable to copy .index.json file in '%s': %s", filepath.Join(t.RuntimeHubPath, ".index.json"), err) - } - - // create results folder - if err = os.MkdirAll(t.ResultsPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", t.ResultsPath, err) + return fmt.Errorf("unable to copy .index.json file in '%s': %w", filepath.Join(t.RuntimeHubPath, ".index.json"), err) } // copy template config file to runtime folder @@ -618,7 +599,7 @@ func (t *HubTestItem) Run() error { // copy template patterns folder to runtime folder if err = CopyDir(crowdsecPatternsFolder, t.RuntimePatternsPath); err != nil { - return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %s", crowdsecPatternsFolder, t.RuntimePatternsPath, err) + return fmt.Errorf("unable to copy 'patterns' from '%s' to '%s': %w", crowdsecPatternsFolder, t.RuntimePatternsPath, err) } // create the appsec-configs dir @@ -626,7 +607,7 @@ func (t *HubTestItem) Run() error { return fmt.Errorf("unable to create folder '%s': %+v", t.RuntimePath, err) } - //if it's an appsec rule test, we need acquis and appsec profile + // if it's an appsec rule test, we need acquis and appsec profile if len(t.Config.AppsecRules) > 0 { // copy template acquis file to runtime folder log.Debugf("copying %s to %s", t.TemplateAcquisPath, t.RuntimeAcquisFilePath) @@ -640,22 +621,21 @@ func (t *HubTestItem) Run() error { if err = Copy(t.TemplateAppsecProfilePath, filepath.Join(t.RuntimePath, "appsec-configs", "config.yaml")); err != nil { return fmt.Errorf("unable to copy '%s' to '%s': %v", t.TemplateAppsecProfilePath, filepath.Join(t.RuntimePath, "appsec-configs", "config.yaml"), err) } - } else { //otherwise we drop a blank acquis file + } else { // otherwise we drop a blank acquis file if err = os.WriteFile(t.RuntimeAcquisFilePath, []byte(""), os.ModePerm); err != nil { - return fmt.Errorf("unable to write blank acquis file '%s': %s", t.RuntimeAcquisFilePath, err) + return fmt.Errorf("unable to write blank acquis file '%s': %w", t.RuntimeAcquisFilePath, err) } } // install the hub in the runtime folder if err = t.InstallHub(); err != nil { - return fmt.Errorf("unable to install hub in '%s': %s", t.RuntimeHubPath, err) + return fmt.Errorf("unable to install hub in '%s': %w", t.RuntimeHubPath, err) } if t.Config.LogFile != "" { return t.RunWithLogFile() } else if t.Config.NucleiTemplate != "" { return t.RunWithNucleiTemplate() - } else { - return fmt.Errorf("log file or nuclei template must be set in '%s'", t.Name) } + return fmt.Errorf("log file or nuclei template must be set in '%s'", t.Name) } diff --git a/pkg/hubtest/nucleirunner.go b/pkg/hubtest/nucleirunner.go index 0bf2013dd8d..32c81eb64d8 100644 --- a/pkg/hubtest/nucleirunner.go +++ b/pkg/hubtest/nucleirunner.go @@ -42,11 +42,11 @@ func (nc *NucleiConfig) RunNucleiTemplate(testName string, templatePath string, err := cmd.Run() - if err := os.WriteFile(outputPrefix+"_stdout.txt", out.Bytes(), 0644); err != nil { + if err := os.WriteFile(outputPrefix+"_stdout.txt", out.Bytes(), 0o644); err != nil { log.Warningf("Error writing stdout: %s", err) } - if err := os.WriteFile(outputPrefix+"_stderr.txt", outErr.Bytes(), 0644); err != nil { + if err := os.WriteFile(outputPrefix+"_stderr.txt", outErr.Bytes(), 0o644); err != nil { log.Warningf("Error writing stderr: %s", err) } @@ -56,7 +56,7 @@ func (nc *NucleiConfig) RunNucleiTemplate(testName string, templatePath string, log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") return err - } else if len(out.String()) == 0 { + } else if out.String() == "" { log.Warningf("Stdout saved to %s", outputPrefix+"_stdout.txt") log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") diff --git a/pkg/hubtest/parser.go b/pkg/hubtest/parser.go index b8dcdb8b1d0..31ff459ca77 100644 --- a/pkg/hubtest/parser.go +++ b/pkg/hubtest/parser.go @@ -9,89 +9,86 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installParserItem(hubParser *cwhub.Item) error { - parserSource, err := filepath.Abs(filepath.Join(t.HubPath, hubParser.RemotePath)) +func (t *HubTestItem) installParserItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", parserSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - parserFileName := filepath.Base(parserSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/parsers/s00-raw/crowdsecurity/ - hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubParser.RemotePath)) + hubDirParserDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/parsers/s00-raw/ - parserDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, hubParser.Stage) + itemTypeDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, item.Stage) - if err := os.MkdirAll(hubDirParserDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirParserDest, err) - } - - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) + if err := createDirs([]string{hubDirParserDest, itemTypeDirDest}); err != nil { + return err } // runtime/hub/parsers/s00-raw/crowdsecurity/syslog-logs.yaml - hubDirParserPath := filepath.Join(hubDirParserDest, parserFileName) - if err := Copy(parserSource, hubDirParserPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", parserSource, hubDirParserPath, err) + hubDirParserPath := filepath.Join(hubDirParserDest, sourceFilename) + if err := Copy(sourcePath, hubDirParserPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirParserPath, err) } // runtime/parsers/s00-raw/syslog-logs.yaml - parserDirParserPath := filepath.Join(parserDirDest, parserFileName) + parserDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirParserPath, parserDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink parser '%s' to '%s': %s", hubDirParserPath, parserDirParserPath, err) + return fmt.Errorf("unable to symlink parser '%s' to '%s': %w", hubDirParserPath, parserDirParserPath, err) } } return nil } -func (t *HubTestItem) installParserCustom(parser string) error { - customParserExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom parser - customParserPath := filepath.Join(customPath, parser) - if _, err := os.Stat(customParserPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("parser '%s' doesn't exist in the hub and doesn't appear to be a custom one.", parser) - } +func (t *HubTestItem) installParserCustomFrom(parser string, customPath string) (bool, error) { + // we check if its a custom parser + customParserPath := filepath.Join(customPath, parser) + if _, err := os.Stat(customParserPath); os.IsNotExist(err) { + return false, nil + } - customParserPathSplit, customParserName := filepath.Split(customParserPath) - // because path is parsers///parser.yaml and we wan't the stage - splittedPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) - customParserStage := splittedPath[len(splittedPath)-3] + customParserPathSplit, customParserName := filepath.Split(customParserPath) + // because path is parsers///parser.yaml and we wan't the stage + splitPath := strings.Split(customParserPathSplit, string(os.PathSeparator)) + customParserStage := splitPath[len(splitPath)-3] - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) + // check if stage exist + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("parsers/%s", customParserStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customParserStage, hubStagePath) + } - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' extracted from '%s' doesn't exist in the hub", customParserStage, hubStagePath) - } + stageDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", stageDirDest, err) + } - parserDirDest := fmt.Sprintf("%s/parsers/%s/", t.RuntimePath, customParserStage) - if err := os.MkdirAll(parserDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder '%s': %s", parserDirDest, err) - } + customParserDest := filepath.Join(stageDirDest, customParserName) + // if path to parser exist, copy it + if err := Copy(customParserPath, customParserDest); err != nil { + return false, fmt.Errorf("unable to copy custom parser '%s' to '%s': %w", customParserPath, customParserDest, err) + } - customParserDest := filepath.Join(parserDirDest, customParserName) - // if path to parser exist, copy it - if err := Copy(customParserPath, customParserDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customParserPath, customParserDest, err) + return true, nil +} + +func (t *HubTestItem) installParserCustom(parser string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installParserCustomFrom(parser, customPath) + if err != nil { + return err } - customParserExist = true - break - } - if !customParserExist { - return fmt.Errorf("couldn't find custom parser '%s' in the following location: %+v", parser, t.CustomItemsLocation) + if found { + return nil + } } - return nil + return fmt.Errorf("couldn't find custom parser '%s' in the following locations: %+v", parser, t.CustomItemsLocation) } func (t *HubTestItem) installParser(name string) error { diff --git a/pkg/hubtest/parser_assert.go b/pkg/hubtest/parser_assert.go index 7eec8e535e5..be4fdbdb5e6 100644 --- a/pkg/hubtest/parser_assert.go +++ b/pkg/hubtest/parser_assert.go @@ -2,17 +2,19 @@ package hubtest import ( "bufio" + "errors" "fmt" "os" "strings" - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/maptools" "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/go-cs-lib/maptools" ) type AssertFail struct { @@ -69,13 +71,12 @@ func (p *ParserAssert) LoadTest(filename string) error { func (p *ParserAssert) AssertFile(testFile string) error { file, err := os.Open(p.File) - if err != nil { - return fmt.Errorf("failed to open") + return errors.New("failed to open") } if err := p.LoadTest(testFile); err != nil { - return fmt.Errorf("unable to load parser dump file '%s': %s", testFile, err) + return fmt.Errorf("unable to load parser dump file '%s': %w", testFile, err) } scanner := bufio.NewScanner(file) @@ -107,6 +108,7 @@ func (p *ParserAssert) AssertFile(testFile string) error { } match := variableRE.FindStringSubmatch(scanner.Text()) + var variable string if len(match) == 0 { @@ -127,7 +129,7 @@ func (p *ParserAssert) AssertFile(testFile string) error { continue } - //fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) + // fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) } file.Close() @@ -135,7 +137,7 @@ func (p *ParserAssert) AssertFile(testFile string) error { if p.NbAssert == 0 { assertData, err := p.AutoGenFromFile(testFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } p.AutoGenAssertData = assertData @@ -150,8 +152,8 @@ func (p *ParserAssert) AssertFile(testFile string) error { } func (p *ParserAssert) RunExpression(expression string) (interface{}, error) { - //debug doesn't make much sense with the ability to evaluate "on the fly" - //var debugFilter *exprhelpers.ExprDebugger + // debug doesn't make much sense with the ability to evaluate "on the fly" + // var debugFilter *exprhelpers.ExprDebugger var output interface{} env := map[string]interface{}{"results": *p.TestData} @@ -162,7 +164,7 @@ func (p *ParserAssert) RunExpression(expression string) (interface{}, error) { return output, err } - //dump opcode in trace level + // dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) output, err = expr.Run(runtimeFilter, env) @@ -183,7 +185,6 @@ func (p *ParserAssert) EvalExpression(expression string) (string, error) { } ret, err := yaml.Marshal(output) - if err != nil { return "", err } @@ -213,16 +214,16 @@ func Escape(val string) string { } func (p *ParserAssert) AutoGenParserAssert() string { - //attempt to autogen parser asserts + // attempt to autogen parser asserts ret := fmt.Sprintf("len(results) == %d\n", len(*p.TestData)) - //sort map keys for consistent order + // sort map keys for consistent order stages := maptools.SortedKeys(*p.TestData) for _, stage := range stages { parsers := (*p.TestData)[stage] - //sort map keys for consistent order + // sort map keys for consistent order pnames := maptools.SortedKeys(parsers) for _, parser := range pnames { diff --git a/pkg/hubtest/postoverflow.go b/pkg/hubtest/postoverflow.go index d5d43ddc742..65fd0bfbc5d 100644 --- a/pkg/hubtest/postoverflow.go +++ b/pkg/hubtest/postoverflow.go @@ -9,88 +9,86 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installPostoverflowItem(hubPostOverflow *cwhub.Item) error { - postoverflowSource, err := filepath.Abs(filepath.Join(t.HubPath, hubPostOverflow.RemotePath)) +func (t *HubTestItem) installPostoverflowItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path of '%s': %s", postoverflowSource, err) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - postoverflowFileName := filepath.Base(postoverflowSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/postoverflows/s00-enrich/crowdsecurity/ - hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubPostOverflow.RemotePath)) + hubDirPostoverflowDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/postoverflows/s00-enrich - postoverflowDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, hubPostOverflow.Stage) + itemTypeDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, item.Stage) - if err := os.MkdirAll(hubDirPostoverflowDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirPostoverflowDest, err) - } - - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", postoverflowDirDest, err) + if err := createDirs([]string{hubDirPostoverflowDest, itemTypeDirDest}); err != nil { + return err } // runtime/hub/postoverflows/s00-enrich/crowdsecurity/rdns.yaml - hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, postoverflowFileName) - if err := Copy(postoverflowSource, hubDirPostoverflowPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", postoverflowSource, hubDirPostoverflowPath, err) + hubDirPostoverflowPath := filepath.Join(hubDirPostoverflowDest, sourceFilename) + if err := Copy(sourcePath, hubDirPostoverflowPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirPostoverflowPath, err) } // runtime/postoverflows/s00-enrich/rdns.yaml - postoverflowDirParserPath := filepath.Join(postoverflowDirDest, postoverflowFileName) + postoverflowDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirPostoverflowPath, postoverflowDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %s", hubDirPostoverflowPath, postoverflowDirParserPath, err) + return fmt.Errorf("unable to symlink postoverflow '%s' to '%s': %w", hubDirPostoverflowPath, postoverflowDirParserPath, err) } } return nil } -func (t *HubTestItem) installPostoverflowCustom(postoverflow string) error { - customPostoverflowExist := false - for _, customPath := range t.CustomItemsLocation { - // we check if its a custom postoverflow - customPostOverflowPath := filepath.Join(customPath, postoverflow) - if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("postoverflow '%s' doesn't exist in the hub and doesn't appear to be a custom one.", postoverflow) - } +func (t *HubTestItem) installPostoverflowCustomFrom(postoverflow string, customPath string) (bool, error) { + // we check if its a custom postoverflow + customPostOverflowPath := filepath.Join(customPath, postoverflow) + if _, err := os.Stat(customPostOverflowPath); os.IsNotExist(err) { + return false, nil + } - customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") - customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] - // because path is postoverflows///parser.yaml and we wan't the stage - customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] + customPostOverflowPathSplit := strings.Split(customPostOverflowPath, "/") + customPostoverflowName := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-1] + // because path is postoverflows///parser.yaml and we wan't the stage + customPostoverflowStage := customPostOverflowPathSplit[len(customPostOverflowPathSplit)-3] - // check if stage exist - hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) + // check if stage exist + hubStagePath := filepath.Join(t.HubPath, fmt.Sprintf("postoverflows/%s", customPostoverflowStage)) + if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { + return false, fmt.Errorf("stage '%s' from extracted '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) + } - if _, err := os.Stat(hubStagePath); os.IsNotExist(err) { - continue - //return fmt.Errorf("stage '%s' from extracted '%s' doesn't exist in the hub", customPostoverflowStage, hubStagePath) - } + stageDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) + if err := os.MkdirAll(stageDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", stageDirDest, err) + } - postoverflowDirDest := fmt.Sprintf("%s/postoverflows/%s/", t.RuntimePath, customPostoverflowStage) - if err := os.MkdirAll(postoverflowDirDest, os.ModePerm); err != nil { - continue - //return fmt.Errorf("unable to create folder '%s': %s", postoverflowDirDest, err) + customPostoverflowDest := filepath.Join(stageDirDest, customPostoverflowName) + // if path to postoverflow exist, copy it + if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { + return false, fmt.Errorf("unable to copy custom parser '%s' to '%s': %w", customPostOverflowPath, customPostoverflowDest, err) + } + + return true, nil +} + +func (t *HubTestItem) installPostoverflowCustom(postoverflow string) error { + for _, customPath := range t.CustomItemsLocation { + found, err := t.installPostoverflowCustomFrom(postoverflow, customPath) + if err != nil { + return err } - customPostoverflowDest := filepath.Join(postoverflowDirDest, customPostoverflowName) - // if path to postoverflow exist, copy it - if err := Copy(customPostOverflowPath, customPostoverflowDest); err != nil { - continue - //return fmt.Errorf("unable to copy custom parser '%s' to '%s': %s", customPostOverflowPath, customPostoverflowDest, err) + if found { + return nil } - customPostoverflowExist = true - break - } - if !customPostoverflowExist { - return fmt.Errorf("couldn't find custom postoverflow '%s' in the following location: %+v", postoverflow, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom postoverflow '%s' in the following location: %+v", postoverflow, t.CustomItemsLocation) } func (t *HubTestItem) installPostoverflow(name string) error { diff --git a/pkg/hubtest/regexp.go b/pkg/hubtest/regexp.go index f9165eae3d1..8b2fcc928dd 100644 --- a/pkg/hubtest/regexp.go +++ b/pkg/hubtest/regexp.go @@ -5,7 +5,7 @@ import ( ) var ( - variableRE = regexp.MustCompile(`(?P[^ =]+) == .*`) - parserResultRE = regexp.MustCompile(`^results\["[^"]+"\]\["(?P[^"]+)"\]\[[0-9]+\]\.Evt\..*`) + variableRE = regexp.MustCompile(`(?P[^ =]+) == .*`) + parserResultRE = regexp.MustCompile(`^results\["[^"]+"\]\["(?P[^"]+)"\]\[[0-9]+\]\.Evt\..*`) scenarioResultRE = regexp.MustCompile(`^results\[[0-9]+\].Overflow.Alert.GetScenario\(\) == "(?P[^"]+)"`) ) diff --git a/pkg/hubtest/scenario.go b/pkg/hubtest/scenario.go index eaa831d8013..7f61e48accf 100644 --- a/pkg/hubtest/scenario.go +++ b/pkg/hubtest/scenario.go @@ -8,74 +8,76 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func (t *HubTestItem) installScenarioItem(hubScenario *cwhub.Item) error { - scenarioSource, err := filepath.Abs(filepath.Join(t.HubPath, hubScenario.RemotePath)) +func (t *HubTestItem) installScenarioItem(item *cwhub.Item) error { + sourcePath, err := filepath.Abs(filepath.Join(t.HubPath, item.RemotePath)) if err != nil { - return fmt.Errorf("can't get absolute path to: %s", scenarioSource) + return fmt.Errorf("can't get absolute path of '%s': %w", sourcePath, err) } - scenarioFileName := filepath.Base(scenarioSource) + sourceFilename := filepath.Base(sourcePath) // runtime/hub/scenarios/crowdsecurity/ - hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(hubScenario.RemotePath)) + hubDirScenarioDest := filepath.Join(t.RuntimeHubPath, filepath.Dir(item.RemotePath)) // runtime/parsers/scenarios/ - scenarioDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) - if err := os.MkdirAll(hubDirScenarioDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", hubDirScenarioDest, err) - } - - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) + if err := createDirs([]string{hubDirScenarioDest, itemTypeDirDest}); err != nil { + return err } // runtime/hub/scenarios/crowdsecurity/ssh-bf.yaml - hubDirScenarioPath := filepath.Join(hubDirScenarioDest, scenarioFileName) - if err := Copy(scenarioSource, hubDirScenarioPath); err != nil { - return fmt.Errorf("unable to copy '%s' to '%s': %s", scenarioSource, hubDirScenarioPath, err) + hubDirScenarioPath := filepath.Join(hubDirScenarioDest, sourceFilename) + if err := Copy(sourcePath, hubDirScenarioPath); err != nil { + return fmt.Errorf("unable to copy '%s' to '%s': %w", sourcePath, hubDirScenarioPath, err) } // runtime/scenarios/ssh-bf.yaml - scenarioDirParserPath := filepath.Join(scenarioDirDest, scenarioFileName) + scenarioDirParserPath := filepath.Join(itemTypeDirDest, sourceFilename) if err := os.Symlink(hubDirScenarioPath, scenarioDirParserPath); err != nil { if !os.IsExist(err) { - return fmt.Errorf("unable to symlink scenario '%s' to '%s': %s", hubDirScenarioPath, scenarioDirParserPath, err) + return fmt.Errorf("unable to symlink scenario '%s' to '%s': %w", hubDirScenarioPath, scenarioDirParserPath, err) } } return nil } +func (t *HubTestItem) installScenarioCustomFrom(scenario string, customPath string) (bool, error) { + // we check if its a custom scenario + customScenarioPath := filepath.Join(customPath, scenario) + if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { + return false, nil + } + + itemTypeDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) + if err := os.MkdirAll(itemTypeDirDest, os.ModePerm); err != nil { + return false, fmt.Errorf("unable to create folder '%s': %w", itemTypeDirDest, err) + } + + scenarioFileName := filepath.Base(customScenarioPath) + + scenarioFileDest := filepath.Join(itemTypeDirDest, scenarioFileName) + if err := Copy(customScenarioPath, scenarioFileDest); err != nil { + return false, fmt.Errorf("unable to copy scenario from '%s' to '%s': %w", customScenarioPath, scenarioFileDest, err) + } + + return true, nil +} + func (t *HubTestItem) installScenarioCustom(scenario string) error { - customScenarioExist := false for _, customPath := range t.CustomItemsLocation { - // we check if its a custom scenario - customScenarioPath := filepath.Join(customPath, scenario) - if _, err := os.Stat(customScenarioPath); os.IsNotExist(err) { - continue - //return fmt.Errorf("scenarios '%s' doesn't exist in the hub and doesn't appear to be a custom one.", scenario) + found, err := t.installScenarioCustomFrom(scenario, customPath) + if err != nil { + return err } - scenarioDirDest := fmt.Sprintf("%s/scenarios/", t.RuntimePath) - if err := os.MkdirAll(scenarioDirDest, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %s", scenarioDirDest, err) + if found { + return nil } - - scenarioFileName := filepath.Base(customScenarioPath) - scenarioFileDest := filepath.Join(scenarioDirDest, scenarioFileName) - if err := Copy(customScenarioPath, scenarioFileDest); err != nil { - continue - //return fmt.Errorf("unable to copy scenario from '%s' to '%s': %s", customScenarioPath, scenarioFileDest, err) - } - customScenarioExist = true - break - } - if !customScenarioExist { - return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, t.CustomItemsLocation) } - return nil + return fmt.Errorf("couldn't find custom scenario '%s' in the following location: %+v", scenario, t.CustomItemsLocation) } func (t *HubTestItem) installScenario(name string) error { diff --git a/pkg/hubtest/scenario_assert.go b/pkg/hubtest/scenario_assert.go index 5195b814ef3..f32abf9e110 100644 --- a/pkg/hubtest/scenario_assert.go +++ b/pkg/hubtest/scenario_assert.go @@ -2,15 +2,16 @@ package hubtest import ( "bufio" + "errors" "fmt" "io" "os" "sort" "strings" - "github.com/antonmedv/expr" + "github.com/expr-lang/expr" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" @@ -77,13 +78,12 @@ func (s *ScenarioAssert) LoadTest(filename string, bucketpour string) error { func (s *ScenarioAssert) AssertFile(testFile string) error { file, err := os.Open(s.File) - if err != nil { - return fmt.Errorf("failed to open") + return errors.New("failed to open") } if err := s.LoadTest(testFile, ""); err != nil { - return fmt.Errorf("unable to load parser dump file '%s': %s", testFile, err) + return fmt.Errorf("unable to load parser dump file '%s': %w", testFile, err) } scanner := bufio.NewScanner(file) @@ -134,7 +134,7 @@ func (s *ScenarioAssert) AssertFile(testFile string) error { continue } - //fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) + // fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text()) } file.Close() @@ -142,7 +142,7 @@ func (s *ScenarioAssert) AssertFile(testFile string) error { if s.NbAssert == 0 { assertData, err := s.AutoGenFromFile(testFile) if err != nil { - return fmt.Errorf("couldn't generate assertion: %s", err) + return fmt.Errorf("couldn't generate assertion: %w", err) } s.AutoGenAssertData = assertData @@ -157,8 +157,8 @@ func (s *ScenarioAssert) AssertFile(testFile string) error { } func (s *ScenarioAssert) RunExpression(expression string) (interface{}, error) { - //debug doesn't make much sense with the ability to evaluate "on the fly" - //var debugFilter *exprhelpers.ExprDebugger + // debug doesn't make much sense with the ability to evaluate "on the fly" + // var debugFilter *exprhelpers.ExprDebugger var output interface{} env := map[string]interface{}{"results": *s.TestData} @@ -171,7 +171,7 @@ func (s *ScenarioAssert) RunExpression(expression string) (interface{}, error) { // log.Warningf("Failed building debugher for %s : %s", assert, err) // } - //dump opcode in trace level + // dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *s.TestData}) diff --git a/pkg/hubtest/utils.go b/pkg/hubtest/utils.go index 9009d0dddec..b42a73461f3 100644 --- a/pkg/hubtest/utils.go +++ b/pkg/hubtest/utils.go @@ -1,6 +1,7 @@ package hubtest import ( + "errors" "fmt" "net" "os" @@ -56,7 +57,7 @@ func checkPathNotContained(path string, subpath string) error { for { if current == absPath { - return fmt.Errorf("cannot copy a folder onto itself") + return errors.New("cannot copy a folder onto itself") } up := filepath.Dir(current) @@ -87,10 +88,10 @@ func CopyDir(src string, dest string) error { } if !file.IsDir() { - return fmt.Errorf("Source " + file.Name() + " is not a directory!") + return errors.New("Source " + file.Name() + " is not a directory!") } - err = os.MkdirAll(dest, 0755) + err = os.MkdirAll(dest, 0o755) if err != nil { return err } diff --git a/pkg/leakybucket/bayesian.go b/pkg/leakybucket/bayesian.go index b8d20a488f9..357d51f597b 100644 --- a/pkg/leakybucket/bayesian.go +++ b/pkg/leakybucket/bayesian.go @@ -3,8 +3,9 @@ package leakybucket import ( "fmt" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -109,7 +110,7 @@ func (b *BayesianEvent) bayesianUpdate(c *BayesianBucket, msg types.Event, l *Le l.logger.Debugf("running condition expression: %s", b.rawCondition.ConditionalFilterName) ret, err := exprhelpers.Run(b.conditionalFilterRuntime, map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}, l.logger, l.BucketConfig.Debug) if err != nil { - return fmt.Errorf("unable to run conditional filter: %s", err) + return fmt.Errorf("unable to run conditional filter: %w", err) } l.logger.Tracef("bayesian bucket expression %s returned : %v", b.rawCondition.ConditionalFilterName, ret) diff --git a/pkg/leakybucket/bucket.go b/pkg/leakybucket/bucket.go index afb5377aa4f..e981551af8f 100644 --- a/pkg/leakybucket/bucket.go +++ b/pkg/leakybucket/bucket.go @@ -6,15 +6,16 @@ import ( "sync/atomic" "time" - "github.com/crowdsecurity/go-cs-lib/trace" - - "github.com/crowdsecurity/crowdsec/pkg/time/rate" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/davecgh/go-spew/spew" "github.com/mohae/deepcopy" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/time/rate" + "github.com/crowdsecurity/crowdsec/pkg/types" ) // those constants are now defined in types/constants diff --git a/pkg/leakybucket/buckets_test.go b/pkg/leakybucket/buckets_test.go index 9e7205e8613..1da906cb555 100644 --- a/pkg/leakybucket/buckets_test.go +++ b/pkg/leakybucket/buckets_test.go @@ -16,6 +16,7 @@ import ( "github.com/davecgh/go-spew/spew" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" yaml "gopkg.in/yaml.v2" @@ -45,14 +46,15 @@ func TestBucket(t *testing.T) { InstallDataDir: testdata, } - hub, err := cwhub.NewHub(hubCfg, nil, false, nil) - if err != nil { - t.Fatalf("failed to init hub: %s", err) - } + hub, err := cwhub.NewHub(hubCfg, nil, nil) + require.NoError(t, err) + + err = hub.Load() + require.NoError(t, err) err = exprhelpers.Init(nil) if err != nil { - log.Fatalf("exprhelpers init failed: %s", err) + t.Fatalf("exprhelpers init failed: %s", err) } if envSetting != "" { @@ -61,25 +63,31 @@ func TestBucket(t *testing.T) { } } else { wg := new(sync.WaitGroup) + fds, err := os.ReadDir(testdata) if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { if fd.Name() == "hub" { continue } + fname := filepath.Join(testdata, fd.Name()) log.Infof("Running test on %s", fname) tomb.Go(func() error { wg.Add(1) defer wg.Done() + if err := testOneBucket(t, hub, fname, tomb); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } + return nil }) } + wg.Wait() } } @@ -88,16 +96,16 @@ func TestBucket(t *testing.T) { // we want to avoid the death of the tomb because all existing buckets have been destroyed. func watchTomb(tomb *tomb.Tomb) { for { - if tomb.Alive() == false { + if !tomb.Alive() { log.Warning("Tomb is dead") break } + time.Sleep(100 * time.Millisecond) } } func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) error { - var ( holders []BucketFactory @@ -105,9 +113,9 @@ func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) er stagecfg string stages []parser.Stagefile err error - buckets *Buckets ) - buckets = NewBuckets() + + buckets := NewBuckets() /*load the scenarios*/ stagecfg = dir + "/scenarios.yaml" @@ -117,51 +125,59 @@ func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) er tmpl, err := template.New("test").Parse(string(stagefiles)) if err != nil { - return fmt.Errorf("failed to parse template %s : %s", stagefiles, err) + return fmt.Errorf("failed to parse template %s: %w", stagefiles, err) } + var out bytes.Buffer + err = tmpl.Execute(&out, map[string]string{"TestDirectory": dir}) if err != nil { panic(err) } + if err := yaml.UnmarshalStrict(out.Bytes(), &stages); err != nil { - log.Fatalf("failed unmarshaling %s : %s", stagecfg, err) + t.Fatalf("failed to parse %s : %s", stagecfg, err) } + files := []string{} for _, x := range stages { files = append(files, x.Filename) } cscfg := &csconfig.CrowdsecServiceCfg{} + holders, response, err := LoadBuckets(cscfg, hub, files, tomb, buckets, false) if err != nil { t.Fatalf("failed loading bucket : %s", err) } + tomb.Go(func() error { watchTomb(tomb) return nil }) + if !testFile(t, filepath.Join(dir, "test.json"), filepath.Join(dir, "in-buckets_state.json"), holders, response, buckets) { return fmt.Errorf("tests from %s failed", dir) } + return nil } func testFile(t *testing.T, file string, bs string, holders []BucketFactory, response chan types.Event, buckets *Buckets) bool { - var results []types.Event var dump bool - //should we restore + // should we restore if _, err := os.Stat(bs); err == nil { dump = true + if err := LoadBucketsState(bs, buckets, holders); err != nil { t.Fatalf("Failed to load bucket state : %s", err) } } /* now we can load the test files */ - //process the yaml + // process the yaml yamlFile, err := os.Open(file) if err != nil { t.Errorf("yamlFile.Get err #%v ", err) @@ -183,9 +199,11 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res //just to avoid any race during ingestion of funny scenarios time.Sleep(50 * time.Millisecond) var ts time.Time + if err := ts.UnmarshalText([]byte(in.MarshaledTime)); err != nil { - t.Fatalf("Failed to unmarshal time from input event : %s", err) + t.Fatalf("Failed to parse time from input event : %s", err) } + if latest_ts.IsZero() { latest_ts = ts } else if ts.After(latest_ts) { @@ -194,10 +212,12 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res in.ExpectMode = types.TIMEMACHINE log.Infof("Buckets input : %s", spew.Sdump(in)) + ok, err := PourItemToHolders(in, holders, buckets) if err != nil { t.Fatalf("Failed to pour : %s", err) } + if !ok { log.Warning("Event wasn't poured") } diff --git a/pkg/leakybucket/conditional.go b/pkg/leakybucket/conditional.go index 5ff69e60a26..a203a639743 100644 --- a/pkg/leakybucket/conditional.go +++ b/pkg/leakybucket/conditional.go @@ -4,8 +4,9 @@ import ( "fmt" "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go index 85eee89d933..b8310b8cb17 100644 --- a/pkg/leakybucket/manager_load.go +++ b/pkg/leakybucket/manager_load.go @@ -11,9 +11,9 @@ import ( "sync" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/goombaio/namegenerator" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -34,115 +34,170 @@ type BucketFactory struct { Author string `yaml:"author"` Description string `yaml:"description"` References []string `yaml:"references"` - Type string `yaml:"type"` //Type can be : leaky, counter, trigger. It determines the main bucket characteristics - Name string `yaml:"name"` //Name of the bucket, used later in log and user-messages. Should be unique - Capacity int `yaml:"capacity"` //Capacity is applicable to leaky buckets and determines the "burst" capacity - LeakSpeed string `yaml:"leakspeed"` //Leakspeed is a float representing how many events per second leak out of the bucket - Duration string `yaml:"duration"` //Duration allows 'counter' buckets to have a fixed life-time - Filter string `yaml:"filter"` //Filter is an expr that determines if an event is elligible for said bucket. Filter is evaluated against the Event struct - GroupBy string `yaml:"groupby,omitempty"` //groupy is an expr that allows to determine the partitions of the bucket. A common example is the source_ip - Distinct string `yaml:"distinct"` //Distinct, when present, adds a `Pour()` processor that will only pour uniq items (based on distinct expr result) - Debug bool `yaml:"debug"` //Debug, when set to true, will enable debugging for _this_ scenario specifically - Labels map[string]interface{} `yaml:"labels"` //Labels is K:V list aiming at providing context the overflow - Blackhole string `yaml:"blackhole,omitempty"` //Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration - logger *log.Entry `yaml:"-"` //logger is bucket-specific logger (used by Debug as well) - Reprocess bool `yaml:"reprocess"` //Reprocess, if true, will for the bucket to be re-injected into processing chain - CacheSize int `yaml:"cache_size"` //CacheSize, if > 0, limits the size of in-memory cache of the bucket - Profiling bool `yaml:"profiling"` //Profiling, if true, will make the bucket record pours/overflows/etc. - OverflowFilter string `yaml:"overflow_filter"` //OverflowFilter if present, is a filter that must return true for the overflow to go through - ConditionalOverflow string `yaml:"condition"` //condition if present, is an expression that must return true for the bucket to overflow + Type string `yaml:"type"` // Type can be : leaky, counter, trigger. It determines the main bucket characteristics + Name string `yaml:"name"` // Name of the bucket, used later in log and user-messages. Should be unique + Capacity int `yaml:"capacity"` // Capacity is applicable to leaky buckets and determines the "burst" capacity + LeakSpeed string `yaml:"leakspeed"` // Leakspeed is a float representing how many events per second leak out of the bucket + Duration string `yaml:"duration"` // Duration allows 'counter' buckets to have a fixed life-time + Filter string `yaml:"filter"` // Filter is an expr that determines if an event is elligible for said bucket. Filter is evaluated against the Event struct + GroupBy string `yaml:"groupby,omitempty"` // groupy is an expr that allows to determine the partitions of the bucket. A common example is the source_ip + Distinct string `yaml:"distinct"` // Distinct, when present, adds a `Pour()` processor that will only pour uniq items (based on distinct expr result) + Debug bool `yaml:"debug"` // Debug, when set to true, will enable debugging for _this_ scenario specifically + Labels map[string]interface{} `yaml:"labels"` // Labels is K:V list aiming at providing context the overflow + Blackhole string `yaml:"blackhole,omitempty"` // Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration + logger *log.Entry // logger is bucket-specific logger (used by Debug as well) + Reprocess bool `yaml:"reprocess"` // Reprocess, if true, will for the bucket to be re-injected into processing chain + CacheSize int `yaml:"cache_size"` // CacheSize, if > 0, limits the size of in-memory cache of the bucket + Profiling bool `yaml:"profiling"` // Profiling, if true, will make the bucket record pours/overflows/etc. + OverflowFilter string `yaml:"overflow_filter"` // OverflowFilter if present, is a filter that must return true for the overflow to go through + ConditionalOverflow string `yaml:"condition"` // condition if present, is an expression that must return true for the bucket to overflow BayesianPrior float32 `yaml:"bayesian_prior"` BayesianThreshold float32 `yaml:"bayesian_threshold"` - BayesianConditions []RawBayesianCondition `yaml:"bayesian_conditions"` //conditions for the bayesian bucket - ScopeType types.ScopeType `yaml:"scope,omitempty"` //to enforce a different remediation than blocking an IP. Will default this to IP + BayesianConditions []RawBayesianCondition `yaml:"bayesian_conditions"` // conditions for the bayesian bucket + ScopeType types.ScopeType `yaml:"scope,omitempty"` // to enforce a different remediation than blocking an IP. Will default this to IP BucketName string `yaml:"-"` Filename string `yaml:"-"` RunTimeFilter *vm.Program `json:"-"` RunTimeGroupBy *vm.Program `json:"-"` Data []*types.DataSource `yaml:"data,omitempty"` DataDir string `yaml:"-"` - CancelOnFilter string `yaml:"cancel_on,omitempty"` //a filter that, if matched, kills the bucket - leakspeed time.Duration //internal representation of `Leakspeed` - duration time.Duration //internal representation of `Duration` - ret chan types.Event //the bucket-specific output chan for overflows - processors []Processor //processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.) - output bool //?? + CancelOnFilter string `yaml:"cancel_on,omitempty"` // a filter that, if matched, kills the bucket + leakspeed time.Duration // internal representation of `Leakspeed` + duration time.Duration // internal representation of `Duration` + ret chan types.Event // the bucket-specific output chan for overflows + processors []Processor // processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.) + output bool // ?? ScenarioVersion string `yaml:"version,omitempty"` - hash string `yaml:"-"` - Simulated bool `yaml:"simulated"` //Set to true if the scenario instantiating the bucket was in the exclusion list - tomb *tomb.Tomb `yaml:"-"` - wgPour *sync.WaitGroup `yaml:"-"` - wgDumpState *sync.WaitGroup `yaml:"-"` + hash string + Simulated bool `yaml:"simulated"` // Set to true if the scenario instantiating the bucket was in the exclusion list + tomb *tomb.Tomb + wgPour *sync.WaitGroup + wgDumpState *sync.WaitGroup orderEvent bool } // we use one NameGenerator for all the future buckets var seed namegenerator.Generator = namegenerator.NewNameGenerator(time.Now().UTC().UnixNano()) +func validateLeakyType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity <= 0 { // capacity must be a positive int + return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) + } + + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for leaky") + } + + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) + } + + return nil +} + +func validateCounterType(bucketFactory *BucketFactory) error { + if bucketFactory.Duration == "" { + return errors.New("duration can't be empty for counter") + } + + if bucketFactory.duration == 0 { + return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) + } + + if bucketFactory.Capacity != -1 { + return errors.New("counter bucket must have -1 capacity") + } + + return nil +} + +func validateTriggerType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity != 0 { + return errors.New("trigger bucket must have 0 capacity") + } + + return nil +} + +func validateConditionalType(bucketFactory *BucketFactory) error { + if bucketFactory.ConditionalOverflow == "" { + return errors.New("conditional bucket must have a condition") + } + + if bucketFactory.Capacity != -1 { + bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") + } + + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for conditional bucket") + } + + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) + } + + return nil +} + +func validateBayesianType(bucketFactory *BucketFactory) error { + if bucketFactory.BayesianConditions == nil { + return errors.New("bayesian bucket must have bayesian conditions") + } + + if bucketFactory.BayesianPrior == 0 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold == 0 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.BayesianPrior > 1 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold > 1 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.Capacity != -1 { + return errors.New("bayesian bucket must have capacity -1") + } + + return nil +} + func ValidateFactory(bucketFactory *BucketFactory) error { if bucketFactory.Name == "" { - return fmt.Errorf("bucket must have name") + return errors.New("bucket must have name") } + if bucketFactory.Description == "" { - return fmt.Errorf("description is mandatory") + return errors.New("description is mandatory") } - if bucketFactory.Type == "leaky" { - if bucketFactory.Capacity <= 0 { //capacity must be a positive int - return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) - } - if bucketFactory.LeakSpeed == "" { - return fmt.Errorf("leakspeed can't be empty for leaky") - } - if bucketFactory.leakspeed == 0 { - return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "counter" { - if bucketFactory.Duration == "" { - return fmt.Errorf("duration can't be empty for counter") - } - if bucketFactory.duration == 0 { - return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) - } - if bucketFactory.Capacity != -1 { - return fmt.Errorf("counter bucket must have -1 capacity") - } - } else if bucketFactory.Type == "trigger" { - if bucketFactory.Capacity != 0 { - return fmt.Errorf("trigger bucket must have 0 capacity") - } - } else if bucketFactory.Type == "conditional" { - if bucketFactory.ConditionalOverflow == "" { - return fmt.Errorf("conditional bucket must have a condition") - } - if bucketFactory.Capacity != -1 { - bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") - } - if bucketFactory.LeakSpeed == "" { - return fmt.Errorf("leakspeed can't be empty for conditional bucket") - } - if bucketFactory.leakspeed == 0 { - return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "bayesian" { - if bucketFactory.BayesianConditions == nil { - return fmt.Errorf("bayesian bucket must have bayesian conditions") - } - if bucketFactory.BayesianPrior == 0 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero prior") + + switch bucketFactory.Type { + case "leaky": + if err := validateLeakyType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianThreshold == 0 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold") + case "counter": + if err := validateCounterType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianPrior > 1 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero prior") + case "trigger": + if err := validateTriggerType(bucketFactory); err != nil { + return err } - if bucketFactory.BayesianThreshold > 1 { - return fmt.Errorf("bayesian bucket must have a valid, non-zero threshold") + case "conditional": + if err := validateConditionalType(bucketFactory); err != nil { + return err } - if bucketFactory.Capacity != -1 { - return fmt.Errorf("bayesian bucket must have capacity -1") + case "bayesian": + if err := validateBayesianType(bucketFactory); err != nil { + return err } - } else { + default: return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type) } @@ -155,26 +210,31 @@ func ValidateFactory(bucketFactory *BucketFactory) error { runTimeFilter *vm.Program err error ) + if bucketFactory.ScopeType.Filter != "" { if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("Error compiling the scope filter: %s", err) + return fmt.Errorf("error compiling the scope filter: %w", err) } + bucketFactory.ScopeType.RunTimeFilter = runTimeFilter } default: - //Compile the scope filter + // Compile the scope filter var ( runTimeFilter *vm.Program err error ) + if bucketFactory.ScopeType.Filter != "" { if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("Error compiling the scope filter: %s", err) + return fmt.Errorf("error compiling the scope filter: %w", err) } + bucketFactory.ScopeType.RunTimeFilter = runTimeFilter } } + return nil } @@ -185,77 +245,86 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str ) response = make(chan types.Event, 1) + for _, f := range files { log.Debugf("Loading '%s'", f) + if !strings.HasSuffix(f, ".yaml") && !strings.HasSuffix(f, ".yml") { log.Debugf("Skipping %s : not a yaml file", f) continue } - //process the yaml + // process the yaml bucketConfigurationFile, err := os.Open(f) if err != nil { log.Errorf("Can't access leaky configuration file %s", f) return nil, nil, err } + defer bucketConfigurationFile.Close() dec := yaml.NewDecoder(bucketConfigurationFile) dec.SetStrict(true) + for { bucketFactory := BucketFactory{} + err = dec.Decode(&bucketFactory) if err != nil { if !errors.Is(err, io.EOF) { - log.Errorf("Bad yaml in %s : %v", f, err) - return nil, nil, fmt.Errorf("bad yaml in %s : %v", f, err) + log.Errorf("Bad yaml in %s: %v", f, err) + return nil, nil, fmt.Errorf("bad yaml in %s: %w", f, err) } + log.Tracef("End of yaml file") + break } + bucketFactory.DataDir = hub.GetDataDir() - //check empty + // check empty if bucketFactory.Name == "" { log.Errorf("Won't load nameless bucket") - return nil, nil, fmt.Errorf("nameless bucket") + return nil, nil, errors.New("nameless bucket") } - //check compat + // check compat if bucketFactory.FormatVersion == "" { log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, f) bucketFactory.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(bucketFactory.FormatVersion, cwversion.Constraint_scenario) + + ok, err := constraint.Satisfies(bucketFactory.FormatVersion, constraint.Scenario) if err != nil { - return nil, nil, fmt.Errorf("failed to check version : %s", err) + return nil, nil, fmt.Errorf("failed to check version: %w", err) } + if !ok { - log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, cwversion.Constraint_scenario) + log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, constraint.Scenario) continue } bucketFactory.Filename = filepath.Clean(f) bucketFactory.BucketName = seed.Generate() bucketFactory.ret = response - hubItem, err := hub.GetItemByPath(cwhub.SCENARIOS, bucketFactory.Filename) - if err != nil { - log.Errorf("scenario %s (%s) couldn't be find in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) + + hubItem := hub.GetItemByPath(bucketFactory.Filename) + if hubItem == nil { + log.Errorf("scenario %s (%s) could not be found in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) } else { if cscfg.SimulationConfig != nil { bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(hubItem.Name) } - if hubItem != nil { - bucketFactory.ScenarioVersion = hubItem.State.LocalVersion - bucketFactory.hash = hubItem.State.LocalHash - } else { - log.Errorf("scenario %s (%s) couldn't be find in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) - } + + bucketFactory.ScenarioVersion = hubItem.State.LocalVersion + bucketFactory.hash = hubItem.State.LocalHash } bucketFactory.wgDumpState = buckets.wgDumpState bucketFactory.wgPour = buckets.wgPour + err = LoadBucket(&bucketFactory, tomb) if err != nil { - log.Errorf("Failed to load bucket %s : %v", bucketFactory.Name, err) - return nil, nil, fmt.Errorf("loading of %s failed : %v", bucketFactory.Name, err) + log.Errorf("Failed to load bucket %s: %v", bucketFactory.Name, err) + return nil, nil, fmt.Errorf("loading of %s failed: %w", bucketFactory.Name, err) } bucketFactory.orderEvent = orderEvent @@ -265,21 +334,24 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str } if err := alertcontext.NewAlertContext(cscfg.ContextToSend, cscfg.ConsoleContextValueLength); err != nil { - return nil, nil, fmt.Errorf("unable to load alert context: %s", err) + return nil, nil, fmt.Errorf("unable to load alert context: %w", err) } log.Infof("Loaded %d scenarios", len(ret)) + return ret, response, nil } /* Init recursively process yaml files from a directory and loads them as BucketFactory */ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { var err error + if bucketFactory.Debug { - var clog = log.New() + clog := log.New() if err := types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating bucket-specific logger : %s", err) + return fmt.Errorf("while creating bucket-specific logger: %w", err) } + clog.SetLevel(log.DebugLevel) bucketFactory.logger = clog.WithFields(log.Fields{ "cfg": bucketFactory.BucketName, @@ -295,35 +367,37 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { if bucketFactory.LeakSpeed != "" { if bucketFactory.leakspeed, err = time.ParseDuration(bucketFactory.LeakSpeed); err != nil { - return fmt.Errorf("bad leakspeed '%s' in %s : %v", bucketFactory.LeakSpeed, bucketFactory.Filename, err) + return fmt.Errorf("bad leakspeed '%s' in %s: %w", bucketFactory.LeakSpeed, bucketFactory.Filename, err) } } else { bucketFactory.leakspeed = time.Duration(0) } + if bucketFactory.Duration != "" { if bucketFactory.duration, err = time.ParseDuration(bucketFactory.Duration); err != nil { - return fmt.Errorf("invalid Duration '%s' in %s : %v", bucketFactory.Duration, bucketFactory.Filename, err) + return fmt.Errorf("invalid Duration '%s' in %s: %w", bucketFactory.Duration, bucketFactory.Filename, err) } } if bucketFactory.Filter == "" { bucketFactory.logger.Warning("Bucket without filter, abort.") - return fmt.Errorf("bucket without filter directive") + return errors.New("bucket without filter directive") } + bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("invalid filter '%s' in %s : %v", bucketFactory.Filter, bucketFactory.Filename, err) + return fmt.Errorf("invalid filter '%s' in %s: %w", bucketFactory.Filter, bucketFactory.Filename, err) } if bucketFactory.GroupBy != "" { bucketFactory.RunTimeGroupBy, err = expr.Compile(bucketFactory.GroupBy, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("invalid groupby '%s' in %s : %v", bucketFactory.GroupBy, bucketFactory.Filename, err) + return fmt.Errorf("invalid groupby '%s' in %s: %w", bucketFactory.GroupBy, bucketFactory.Filename, err) } } bucketFactory.logger.Infof("Adding %s bucket", bucketFactory.Type) - //return the Holder corresponding to the type of bucket + // return the Holder corresponding to the type of bucket bucketFactory.processors = []Processor{} switch bucketFactory.Type { case "leaky": @@ -337,7 +411,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { case "bayesian": bucketFactory.processors = append(bucketFactory.processors, &DumbProcessor{}) default: - return fmt.Errorf("invalid type '%s' in %s : %v", bucketFactory.Type, bucketFactory.Filename, err) + return fmt.Errorf("invalid type '%s' in %s: %w", bucketFactory.Type, bucketFactory.Filename, err) } if bucketFactory.Distinct != "" { @@ -352,21 +426,25 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { if bucketFactory.OverflowFilter != "" { bucketFactory.logger.Tracef("Adding an overflow filter") + filovflw, err := NewOverflowFilter(bucketFactory) if err != nil { bucketFactory.logger.Errorf("Error creating overflow_filter : %s", err) - return fmt.Errorf("error creating overflow_filter : %s", err) + return fmt.Errorf("error creating overflow_filter: %w", err) } + bucketFactory.processors = append(bucketFactory.processors, filovflw) } if bucketFactory.Blackhole != "" { bucketFactory.logger.Tracef("Adding blackhole.") + blackhole, err := NewBlackhole(bucketFactory) if err != nil { bucketFactory.logger.Errorf("Error creating blackhole : %s", err) - return fmt.Errorf("error creating blackhole : %s", err) + return fmt.Errorf("error creating blackhole : %w", err) } + bucketFactory.processors = append(bucketFactory.processors, blackhole) } @@ -380,87 +458,98 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { bucketFactory.processors = append(bucketFactory.processors, &BayesianBucket{}) } - if len(bucketFactory.Data) > 0 { - for _, data := range bucketFactory.Data { - if data.DestPath == "" { - bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name) - continue - } - err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type) - if err != nil { - bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err) - } - if data.Type == "regexp" { //cache only makes sense for regexp - exprhelpers.RegexpCacheInit(data.DestPath, *data) - } + for _, data := range bucketFactory.Data { + if data.DestPath == "" { + bucketFactory.logger.Errorf("no dest_file provided for '%s'", bucketFactory.Name) + continue + } + + err = exprhelpers.FileInit(bucketFactory.DataDir, data.DestPath, data.Type) + if err != nil { + bucketFactory.logger.Errorf("unable to init data for file '%s': %s", data.DestPath, err) + } + + if data.Type == "regexp" { // cache only makes sense for regexp + exprhelpers.RegexpCacheInit(data.DestPath, *data) } } bucketFactory.output = false if err := ValidateFactory(bucketFactory); err != nil { - return fmt.Errorf("invalid bucket from %s : %v", bucketFactory.Filename, err) + return fmt.Errorf("invalid bucket from %s: %w", bucketFactory.Filename, err) } + bucketFactory.tomb = tomb return nil - } func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFactory) error { var state map[string]Leaky + body, err := os.ReadFile(file) if err != nil { - return fmt.Errorf("can't state file %s : %s", file, err) + return fmt.Errorf("can't read state file %s: %w", file, err) } + if err := json.Unmarshal(body, &state); err != nil { - return fmt.Errorf("can't unmarshal state file %s : %s", file, err) + return fmt.Errorf("can't parse state file %s: %w", file, err) } + for k, v := range state { var tbucket *Leaky + log.Debugf("Reloading bucket %s", k) + val, ok := buckets.Bucket_map.Load(k) if ok { - log.Fatalf("key %s already exists : %+v", k, val) + return fmt.Errorf("key %s already exists: %+v", k, val) } - //find back our holder + // find back our holder found := false + for _, h := range bucketFactories { - if h.Name == v.Name { - log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) - //check in which mode the bucket was - if v.Mode == types.TIMEMACHINE { - tbucket = NewTimeMachine(h) - } else if v.Mode == types.LIVE { - tbucket = NewLeaky(h) - } else { - log.Errorf("Unknown bucket type : %d", v.Mode) - } - /*Trying to restore queue state*/ - tbucket.Queue = v.Queue - /*Trying to set the limiter to the saved values*/ - tbucket.Limiter.Load(v.SerializedState) - tbucket.In = make(chan *types.Event) - tbucket.Mapkey = k - tbucket.Signal = make(chan bool, 1) - tbucket.First_ts = v.First_ts - tbucket.Last_ts = v.Last_ts - tbucket.Ovflw_ts = v.Ovflw_ts - tbucket.Total_count = v.Total_count - buckets.Bucket_map.Store(k, tbucket) - h.tomb.Go(func() error { - return LeakRoutine(tbucket) - }) - <-tbucket.Signal - found = true - break + if h.Name != v.Name { + continue + } + + log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) + // check in which mode the bucket was + if v.Mode == types.TIMEMACHINE { + tbucket = NewTimeMachine(h) + } else if v.Mode == types.LIVE { + tbucket = NewLeaky(h) + } else { + log.Errorf("Unknown bucket type : %d", v.Mode) } + /*Trying to restore queue state*/ + tbucket.Queue = v.Queue + /*Trying to set the limiter to the saved values*/ + tbucket.Limiter.Load(v.SerializedState) + tbucket.In = make(chan *types.Event) + tbucket.Mapkey = k + tbucket.Signal = make(chan bool, 1) + tbucket.First_ts = v.First_ts + tbucket.Last_ts = v.Last_ts + tbucket.Ovflw_ts = v.Ovflw_ts + tbucket.Total_count = v.Total_count + buckets.Bucket_map.Store(k, tbucket) + h.tomb.Go(func() error { + return LeakRoutine(tbucket) + }) + <-tbucket.Signal + + found = true + + break } + if !found { - log.Fatalf("Unable to find holder for bucket %s : %s", k, spew.Sdump(v)) + return fmt.Errorf("unable to find holder for bucket %s: %s", k, spew.Sdump(v)) } } log.Infof("Restored %d buckets from dump", len(state)) - return nil + return nil } diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index ae7a86a4e4e..2858d8b5635 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -85,7 +85,7 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) defer buckets.wgDumpState.Done() if outputdir == "" { - return "", fmt.Errorf("empty output dir for dump bucket state") + return "", errors.New("empty output dir for dump bucket state") } tmpFd, err := os.CreateTemp(os.TempDir(), "crowdsec-buckets-dump-") if err != nil { @@ -132,11 +132,11 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) }) bbuckets, err := json.MarshalIndent(serialized, "", " ") if err != nil { - return "", fmt.Errorf("Failed to unmarshal buckets : %s", err) + return "", fmt.Errorf("failed to parse buckets: %s", err) } size, err := tmpFd.Write(bbuckets) if err != nil { - return "", fmt.Errorf("failed to write temp file : %s", err) + return "", fmt.Errorf("failed to write temp file: %s", err) } log.Infof("Serialized %d live buckets (+%d expired) in %d bytes to %s", len(serialized), discard, size, tmpFd.Name()) serialized = nil @@ -203,7 +203,7 @@ func PourItemToBucket(bucket *Leaky, holder BucketFactory, buckets *Buckets, par var d time.Time err = d.UnmarshalText([]byte(parsed.MarshaledTime)) if err != nil { - holder.logger.Warningf("Failed unmarshaling event time (%s) : %v", parsed.MarshaledTime, err) + holder.logger.Warningf("Failed to parse event time (%s) : %v", parsed.MarshaledTime, err) } if d.After(lastTs.Add(bucket.Duration)) { bucket.logger.Tracef("bucket is expired (curr event: %s, bucket deadline: %s), kill", d, lastTs.Add(bucket.Duration)) @@ -298,7 +298,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc BucketPourCache["OK"] = append(BucketPourCache["OK"], evt.(types.Event)) } //find the relevant holders (scenarios) - for idx := 0; idx < len(holders); idx++ { + for idx := range holders { //for idx, holder := range holders { //evaluate bucket's condition diff --git a/pkg/leakybucket/manager_run_test.go b/pkg/leakybucket/manager_run_test.go index 27b665f750c..f3fe08b697a 100644 --- a/pkg/leakybucket/manager_run_test.go +++ b/pkg/leakybucket/manager_run_test.go @@ -5,9 +5,10 @@ import ( "testing" "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) func expectBucketCount(buckets *Buckets, expected int) error { @@ -20,7 +21,6 @@ func expectBucketCount(buckets *Buckets, expected int) error { return fmt.Errorf("expected %d live buckets, got %d", expected, count) } return nil - } func TestGCandDump(t *testing.T) { @@ -29,7 +29,7 @@ func TestGCandDump(t *testing.T) { tomb = &tomb.Tomb{} ) - var Holders = []BucketFactory{ + Holders := []BucketFactory{ //one overflowing soon + bh { Name: "test_counter_fast", @@ -80,7 +80,7 @@ func TestGCandDump(t *testing.T) { log.Printf("Pouring to bucket") - var in = types.Event{Parsed: map[string]string{"something": "something"}} + in := types.Event{Parsed: map[string]string{"something": "something"}} //pour an item that will go to leaky + counter ok, err := PourItemToHolders(in, Holders, buckets) if err != nil { @@ -156,7 +156,7 @@ func TestShutdownBuckets(t *testing.T) { log.Printf("Pouring to bucket") - var in = types.Event{Parsed: map[string]string{"something": "something"}} + in := types.Event{Parsed: map[string]string{"something": "something"}} //pour an item that will go to leaky + counter ok, err := PourItemToHolders(in, Holders, buckets) if err != nil { @@ -178,5 +178,4 @@ func TestShutdownBuckets(t *testing.T) { if err := expectBucketCount(buckets, 2); err != nil { t.Fatal(err) } - } diff --git a/pkg/leakybucket/overflow_filter.go b/pkg/leakybucket/overflow_filter.go index 8ec701a3400..01dd491ed41 100644 --- a/pkg/leakybucket/overflow_filter.go +++ b/pkg/leakybucket/overflow_filter.go @@ -3,8 +3,8 @@ package leakybucket import ( "fmt" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 80226aafb2a..39b0e6a0ec4 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -1,6 +1,7 @@ package leakybucket import ( + "errors" "fmt" "net" "sort" @@ -18,103 +19,131 @@ import ( // SourceFromEvent extracts and formats a valid models.Source object from an Event func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { - srcs := make(map[string]models.Source) /*if it's already an overflow, we have properly formatted sources. we can just twitch them to reflect the requested scope*/ if evt.Type == types.OVFLW { + return overflowEventSources(evt, leaky) + } - for k, v := range evt.Overflow.Sources { + return eventSources(evt, leaky) +} - /*the scopes are already similar, nothing to do*/ - if leaky.scopeType.Scope == *v.Scope { - srcs[k] = v - continue - } +func overflowEventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) - /*The bucket requires a decision on scope Range */ - if leaky.scopeType.Scope == types.Range { - /*the original bucket was target IPs, check that we do have range*/ - if *v.Scope == types.Ip { - src := models.Source{} - src.AsName = v.AsName - src.AsNumber = v.AsNumber - src.Cn = v.Cn - src.Latitude = v.Latitude - src.Longitude = v.Longitude - src.Range = v.Range - src.Value = new(string) - src.Scope = new(string) - *src.Scope = leaky.scopeType.Scope - *src.Value = "" - if v.Range != "" { - *src.Value = v.Range - } - if leaky.scopeType.RunTimeFilter != nil { - retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) - if err != nil { - return srcs, fmt.Errorf("while running scope filter: %w", err) - } - value, ok := retValue.(string) - if !ok { - value = "" - } - src.Value = &value + for k, v := range evt.Overflow.Sources { + /*the scopes are already similar, nothing to do*/ + if leaky.scopeType.Scope == *v.Scope { + srcs[k] = v + continue + } + + /*The bucket requires a decision on scope Range */ + if leaky.scopeType.Scope == types.Range { + /*the original bucket was target IPs, check that we do have range*/ + if *v.Scope == types.Ip { + src := models.Source{} + src.AsName = v.AsName + src.AsNumber = v.AsNumber + src.Cn = v.Cn + src.Latitude = v.Latitude + src.Longitude = v.Longitude + src.Range = v.Range + src.Value = new(string) + src.Scope = new(string) + *src.Scope = leaky.scopeType.Scope + *src.Value = "" + + if v.Range != "" { + *src.Value = v.Range + } + + if leaky.scopeType.RunTimeFilter != nil { + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) + if err != nil { + return srcs, fmt.Errorf("while running scope filter: %w", err) } - if *src.Value != "" { - srcs[*src.Value] = src - } else { - log.Warningf("bucket %s requires scope Range, but none was provided. It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) + + value, ok := retValue.(string) + if !ok { + value = "" } + + src.Value = &value + } + + if *src.Value != "" { + srcs[*src.Value] = src } else { - log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", - leaky.Name, *v.Scope, *v.Value) + log.Warningf("bucket %s requires scope Range, but none was provided. It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) } + } else { + log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", + leaky.Name, *v.Scope, *v.Value) } } - return srcs, nil } + + return srcs, nil +} + +func eventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) + src := models.Source{} + switch leaky.scopeType.Scope { case types.Range, types.Ip: v, ok := evt.Meta["source_ip"] if !ok { return srcs, fmt.Errorf("scope is %s but Meta[source_ip] doesn't exist", leaky.scopeType.Scope) } + if net.ParseIP(v) == nil { return srcs, fmt.Errorf("scope is %s but '%s' isn't a valid ip", leaky.scopeType.Scope, v) } + src.IP = v src.Scope = &leaky.scopeType.Scope + if v, ok := evt.Enriched["ASNumber"]; ok { src.AsNumber = v } else if v, ok := evt.Enriched["ASNNumber"]; ok { src.AsNumber = v } + if v, ok := evt.Enriched["IsoCode"]; ok { src.Cn = v } + if v, ok := evt.Enriched["ASNOrg"]; ok { src.AsName = v } + if v, ok := evt.Enriched["Latitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad latitude %s : %s", v, err) } + src.Latitude = float32(l) } + if v, ok := evt.Enriched["Longitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad longitude %s : %s", v, err) } + src.Longitude = float32(l) } + if v, ok := evt.Meta["SourceRange"]; ok && v != "" { _, ipNet, err := net.ParseCIDR(v) if err != nil { - return srcs, fmt.Errorf("Declared range %s of %s can't be parsed", v, src.IP) + return srcs, fmt.Errorf("declared range %s of %s can't be parsed", v, src.IP) } + if ipNet != nil { src.Range = ipNet.String() leaky.logger.Tracef("Valid range from %s : %s", src.IP, src.Range) @@ -124,6 +153,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e src.Value = &src.IP } else if leaky.scopeType.Scope == types.Range { src.Value = &src.Range + if leaky.scopeType.RunTimeFilter != nil { retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { @@ -134,14 +164,17 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value } } + srcs[*src.Value] = src default: if leaky.scopeType.RunTimeFilter == nil { - return srcs, fmt.Errorf("empty scope information") + return srcs, errors.New("empty scope information") } + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { return srcs, fmt.Errorf("while running scope filter: %w", err) @@ -151,30 +184,34 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value src.Scope = new(string) *src.Scope = leaky.scopeType.Scope srcs[*src.Value] = src } + return srcs, nil } // EventsFromQueue iterates the queue to collect & prepare meta-datas from alert func EventsFromQueue(queue *types.Queue) []*models.Event { - events := []*models.Event{} for _, evt := range queue.Queue { if evt.Meta == nil { continue } + meta := models.Meta{} - //we want consistence + // we want consistence skeys := make([]string, 0, len(evt.Meta)) for k := range evt.Meta { skeys = append(skeys, k) } + sort.Strings(skeys) + for _, k := range skeys { v := evt.Meta[k] subMeta := models.MetaItems0{Key: k, Value: v} @@ -185,15 +222,16 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { ovflwEvent := models.Event{ Meta: meta, } - //either MarshaledTime is present and is extracted from log + // either MarshaledTime is present and is extracted from log if evt.MarshaledTime != "" { tmpTimeStamp := evt.MarshaledTime ovflwEvent.Timestamp = &tmpTimeStamp - } else if !evt.Time.IsZero() { //or .Time has been set during parse as time.Now().UTC() + } else if !evt.Time.IsZero() { // or .Time has been set during parse as time.Now().UTC() ovflwEvent.Timestamp = new(string) + raw, err := evt.Time.MarshalText() if err != nil { - log.Warningf("while marshaling time '%s' : %s", evt.Time.String(), err) + log.Warningf("while serializing time '%s' : %s", evt.Time.String(), err) } else { *ovflwEvent.Timestamp = string(raw) } @@ -203,14 +241,16 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { events = append(events, &ovflwEvent) } + return events } // alertFormatSource iterates over the queue to collect sources func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Source, string, error) { - var sources = make(map[string]models.Source) var source_type string + sources := make(map[string]models.Source) + log.Debugf("Formatting (%s) - scope Info : scope_type:%s / scope_filter:%s", leaky.Name, leaky.scopeType.Scope, leaky.scopeType.Filter) for _, evt := range queue.Queue { @@ -218,17 +258,21 @@ func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Sour if err != nil { return nil, "", fmt.Errorf("while extracting scope from bucket %s: %w", leaky.Name, err) } + for key, src := range srcs { if source_type == types.Undefined { source_type = *src.Scope } + if *src.Scope != source_type { return nil, "", fmt.Errorf("event has multiple source types : %s != %s", *src.Scope, source_type) } + sources[key] = src } } + return sources, source_type, nil } @@ -242,12 +286,14 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { */ start_at, err := leaky.First_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal start ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize start ts %s : %s", leaky.First_ts.String(), err) } + stop_at, err := leaky.Ovflw_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal ovflw ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize ovflw ts %s : %s", leaky.First_ts.String(), err) } + capacity := int32(leaky.Capacity) EventsCount := int32(leaky.Total_count) leakSpeed := leaky.Leakspeed.String() @@ -265,20 +311,22 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { StopAt: &stopAt, Simulated: &leaky.Simulated, } + if leaky.BucketConfig == nil { - return runtimeAlert, fmt.Errorf("leaky.BucketConfig is nil") + return runtimeAlert, errors.New("leaky.BucketConfig is nil") } - //give information about the bucket + // give information about the bucket runtimeAlert.Mapkey = leaky.Mapkey - //Get the sources from Leaky/Queue + // Get the sources from Leaky/Queue sources, source_scope, err := alertFormatSource(leaky, queue) if err != nil { return runtimeAlert, fmt.Errorf("unable to collect sources from bucket: %w", err) } + runtimeAlert.Sources = sources - //Include source info in format string + // Include source info in format string sourceStr := "UNKNOWN" if len(sources) > 1 { sourceStr = fmt.Sprintf("%d sources", len(sources)) @@ -290,20 +338,23 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { } *apiAlert.Message = fmt.Sprintf("%s %s performed '%s' (%d events over %s) at %s", source_scope, sourceStr, leaky.Name, leaky.Total_count, leaky.Ovflw_ts.Sub(leaky.First_ts), leaky.Last_ts) - //Get the events from Leaky/Queue + // Get the events from Leaky/Queue apiAlert.Events = EventsFromQueue(queue) + var warnings []error + apiAlert.Meta, warnings = alertcontext.EventToContext(leaky.Queue.GetQueue()) for _, w := range warnings { log.Warningf("while extracting context from bucket %s : %s", leaky.Name, w) } - //Loop over the Sources and generate appropriate number of ApiAlerts + // Loop over the Sources and generate appropriate number of ApiAlerts for _, srcValue := range sources { newApiAlert := apiAlert srcCopy := srcValue newApiAlert.Source = &srcCopy - if v, ok := leaky.BucketConfig.Labels["remediation"]; ok && v == true { + + if v, ok := leaky.BucketConfig.Labels["remediation"]; ok && v == true { //nolint:revive newApiAlert.Remediation = true } @@ -312,6 +363,7 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { log.Errorf("->%s", spew.Sdump(newApiAlert)) log.Fatalf("error : %s", err) } + runtimeAlert.APIAlerts = append(runtimeAlert.APIAlerts, newApiAlert) } @@ -322,5 +374,6 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { if leaky.Reprocess { runtimeAlert.Reprocess = true } + return runtimeAlert, nil } diff --git a/pkg/leakybucket/reset_filter.go b/pkg/leakybucket/reset_filter.go index 5884bf4a10c..452ccc085b1 100644 --- a/pkg/leakybucket/reset_filter.go +++ b/pkg/leakybucket/reset_filter.go @@ -3,8 +3,8 @@ package leakybucket import ( "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -82,22 +82,22 @@ func (u *CancelOnFilter) OnBucketInit(bucketFactory *BucketFactory) error { cancelExprCacheLock.Unlock() u.CancelOnFilter = compiled.CancelOnFilter return nil - } else { - cancelExprCacheLock.Unlock() - //release the lock during compile + } - compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - bucketFactory.logger.Errorf("reset_filter compile error : %s", err) - return err - } - u.CancelOnFilter = compiledExpr.CancelOnFilter - if bucketFactory.Debug { - u.Debug = true - } - cancelExprCacheLock.Lock() - cancelExprCache[bucketFactory.CancelOnFilter] = compiledExpr - cancelExprCacheLock.Unlock() + cancelExprCacheLock.Unlock() + //release the lock during compile + + compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + if err != nil { + bucketFactory.logger.Errorf("reset_filter compile error : %s", err) + return err } - return err + u.CancelOnFilter = compiledExpr.CancelOnFilter + if bucketFactory.Debug { + u.Debug = true + } + cancelExprCacheLock.Lock() + cancelExprCache[bucketFactory.CancelOnFilter] = compiledExpr + cancelExprCacheLock.Unlock() + return nil } diff --git a/pkg/leakybucket/timemachine.go b/pkg/leakybucket/timemachine.go index 266a8be7c69..34073d1cc5c 100644 --- a/pkg/leakybucket/timemachine.go +++ b/pkg/leakybucket/timemachine.go @@ -3,8 +3,9 @@ package leakybucket import ( "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) func TimeMachinePour(l *Leaky, msg types.Event) { @@ -23,7 +24,7 @@ func TimeMachinePour(l *Leaky, msg types.Event) { err = d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) return } diff --git a/pkg/leakybucket/trigger.go b/pkg/leakybucket/trigger.go index d50d7ecc732..d13e57856f9 100644 --- a/pkg/leakybucket/trigger.go +++ b/pkg/leakybucket/trigger.go @@ -3,8 +3,9 @@ package leakybucket import ( "time" - "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/types" ) type Trigger struct { @@ -15,25 +16,31 @@ func (t *Trigger) OnBucketPour(b *BucketFactory) func(types.Event, *Leaky) *type // Pour makes the bucket overflow all the time // TriggerPour unconditionally overflows return func(msg types.Event, l *Leaky) *types.Event { + now := time.Now().UTC() + if l.Mode == types.TIMEMACHINE { var d time.Time + err := d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) - d = time.Now().UTC() + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) + + d = now } + l.logger.Debugf("yay timemachine overflow time : %s --> %s", d, msg.MarshaledTime) l.Last_ts = d l.First_ts = d l.Ovflw_ts = d } else { - l.Last_ts = time.Now().UTC() - l.First_ts = time.Now().UTC() - l.Ovflw_ts = time.Now().UTC() + l.Last_ts = now + l.First_ts = now + l.Ovflw_ts = now } + l.Total_count = 1 - l.logger.Infof("Bucket overflow") + l.logger.Debug("Bucket overflow") l.Queue.Add(msg) l.Out <- l.Queue diff --git a/pkg/leakybucket/uniq.go b/pkg/leakybucket/uniq.go index 06d1e154a6f..0cc0583390b 100644 --- a/pkg/leakybucket/uniq.go +++ b/pkg/leakybucket/uniq.go @@ -3,8 +3,8 @@ package leakybucket import ( "sync" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -39,11 +39,9 @@ func (u *Uniq) OnBucketPour(bucketFactory *BucketFactory) func(types.Event, *Lea leaky.logger.Debugf("Uniq(%s) : ok", element) u.KeyCache[element] = true return &msg - - } else { - leaky.logger.Debugf("Uniq(%s) : ko, discard event", element) - return nil } + leaky.logger.Debugf("Uniq(%s) : ko, discard event", element) + return nil } } diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index e93870a2869..5a7af0bfa63 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -46,7 +46,7 @@ type pollResponse struct { ErrorMessage string `json:"error"` } -var errUnauthorized = fmt.Errorf("user is not authorized to use PAPI") +var errUnauthorized = errors.New("user is not authorized to use PAPI") const timeoutMessage = "no events before timeout" @@ -74,11 +74,9 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { } func (c *LongPollClient) poll() error { - logger := c.logger.WithField("method", "poll") resp, err := c.doQuery() - if err != nil { return err } @@ -95,7 +93,7 @@ func (c *LongPollClient) poll() error { logger.Errorf("failed to read response body: %s", err) return err } - logger.Errorf(string(bodyContent)) + logger.Error(string(bodyContent)) return errUnauthorized } return fmt.Errorf("unexpected status code: %d", resp.StatusCode) @@ -122,7 +120,7 @@ func (c *LongPollClient) poll() error { logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { + if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { logger.Debugf("got timeout message") return nil @@ -209,7 +207,7 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { c.logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { + if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { c.logger.Debugf("got timeout message") break @@ -225,7 +223,7 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { func NewLongPollClient(config LongPollClientConfig) (*LongPollClient, error) { var logger *log.Entry if config.Url == (url.URL{}) { - return nil, fmt.Errorf("url is required") + return nil, errors.New("url is required") } if config.Logger == nil { logger = log.WithField("component", "longpollclient") diff --git a/pkg/metabase/api.go b/pkg/metabase/api.go index bded4c9e83d..08e10188678 100644 --- a/pkg/metabase/api.go +++ b/pkg/metabase/api.go @@ -6,10 +6,10 @@ import ( "net/http" "time" - "github.com/crowdsecurity/go-cs-lib/version" - "github.com/dghubble/sling" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) type MBClient struct { @@ -38,7 +38,7 @@ var ( func NewMBClient(url string) (*MBClient, error) { httpClient := &http.Client{Timeout: 20 * time.Second} return &MBClient{ - CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", fmt.Sprintf("crowdsec/%s", version.String())), + CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", useragent.Default()), Client: httpClient, }, nil } diff --git a/pkg/metabase/metabase.go b/pkg/metabase/metabase.go index 837bab796d5..324a05666a1 100644 --- a/pkg/metabase/metabase.go +++ b/pkg/metabase/metabase.go @@ -70,12 +70,12 @@ func (m *Metabase) Init(containerName string, image string) error { switch m.Config.Database.Type { case "mysql": - return fmt.Errorf("'mysql' is not supported yet for cscli dashboard") + return errors.New("'mysql' is not supported yet for cscli dashboard") //DBConnectionURI = fmt.Sprintf("MB_DB_CONNECTION_URI=mysql://%s:%d/%s?user=%s&password=%s&allowPublicKeyRetrieval=true", remoteDBAddr, m.Config.Database.Port, m.Config.Database.DbName, m.Config.Database.User, m.Config.Database.Password) case "sqlite": m.InternalDBURL = metabaseSQLiteDBURL case "postgresql", "postgres", "pgsql": - return fmt.Errorf("'postgresql' is not supported yet by cscli dashboard") + return errors.New("'postgresql' is not supported yet by cscli dashboard") default: return fmt.Errorf("database '%s' not supported", m.Config.Database.Type) } diff --git a/pkg/models/add_alerts_request.go b/pkg/models/add_alerts_request.go index fd7246be066..a69934ef770 100644 --- a/pkg/models/add_alerts_request.go +++ b/pkg/models/add_alerts_request.go @@ -54,6 +54,11 @@ func (m AddAlertsRequest) ContextValidate(ctx context.Context, formats strfmt.Re for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/alert.go b/pkg/models/alert.go index ec769a1fbb1..895f5ad76e1 100644 --- a/pkg/models/alert.go +++ b/pkg/models/alert.go @@ -399,6 +399,11 @@ func (m *Alert) contextValidateDecisions(ctx context.Context, formats strfmt.Reg for i := 0; i < len(m.Decisions); i++ { if m.Decisions[i] != nil { + + if swag.IsZero(m.Decisions[i]) { // not required + return nil + } + if err := m.Decisions[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("decisions" + "." + strconv.Itoa(i)) @@ -419,6 +424,11 @@ func (m *Alert) contextValidateEvents(ctx context.Context, formats strfmt.Regist for i := 0; i < len(m.Events); i++ { if m.Events[i] != nil { + + if swag.IsZero(m.Events[i]) { // not required + return nil + } + if err := m.Events[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("events" + "." + strconv.Itoa(i)) @@ -469,6 +479,7 @@ func (m *Alert) contextValidateMeta(ctx context.Context, formats strfmt.Registry func (m *Alert) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/models/all_metrics.go b/pkg/models/all_metrics.go new file mode 100644 index 00000000000..5865070e8ef --- /dev/null +++ b/pkg/models/all_metrics.go @@ -0,0 +1,234 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// AllMetrics AllMetrics +// +// swagger:model AllMetrics +type AllMetrics struct { + + // lapi + Lapi *LapiMetrics `json:"lapi,omitempty"` + + // log processors metrics + LogProcessors []*LogProcessorsMetrics `json:"log_processors"` + + // remediation components metrics + RemediationComponents []*RemediationComponentsMetrics `json:"remediation_components"` +} + +// Validate validates this all metrics +func (m *AllMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateLapi(formats); err != nil { + res = append(res, err) + } + + if err := m.validateLogProcessors(formats); err != nil { + res = append(res, err) + } + + if err := m.validateRemediationComponents(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllMetrics) validateLapi(formats strfmt.Registry) error { + if swag.IsZero(m.Lapi) { // not required + return nil + } + + if m.Lapi != nil { + if err := m.Lapi.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("lapi") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("lapi") + } + return err + } + } + + return nil +} + +func (m *AllMetrics) validateLogProcessors(formats strfmt.Registry) error { + if swag.IsZero(m.LogProcessors) { // not required + return nil + } + + for i := 0; i < len(m.LogProcessors); i++ { + if swag.IsZero(m.LogProcessors[i]) { // not required + continue + } + + if m.LogProcessors[i] != nil { + if err := m.LogProcessors[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *AllMetrics) validateRemediationComponents(formats strfmt.Registry) error { + if swag.IsZero(m.RemediationComponents) { // not required + return nil + } + + for i := 0; i < len(m.RemediationComponents); i++ { + if swag.IsZero(m.RemediationComponents[i]) { // not required + continue + } + + if m.RemediationComponents[i] != nil { + if err := m.RemediationComponents[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// ContextValidate validate this all metrics based on the context it is used +func (m *AllMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateLapi(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateLogProcessors(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateRemediationComponents(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllMetrics) contextValidateLapi(ctx context.Context, formats strfmt.Registry) error { + + if m.Lapi != nil { + + if swag.IsZero(m.Lapi) { // not required + return nil + } + + if err := m.Lapi.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("lapi") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("lapi") + } + return err + } + } + + return nil +} + +func (m *AllMetrics) contextValidateLogProcessors(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.LogProcessors); i++ { + + if m.LogProcessors[i] != nil { + + if swag.IsZero(m.LogProcessors[i]) { // not required + return nil + } + + if err := m.LogProcessors[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("log_processors" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *AllMetrics) contextValidateRemediationComponents(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.RemediationComponents); i++ { + + if m.RemediationComponents[i] != nil { + + if swag.IsZero(m.RemediationComponents[i]) { // not required + return nil + } + + if err := m.RemediationComponents[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("remediation_components" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// MarshalBinary interface implementation +func (m *AllMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *AllMetrics) UnmarshalBinary(b []byte) error { + var res AllMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/base_metrics.go b/pkg/models/base_metrics.go new file mode 100644 index 00000000000..94691ea233e --- /dev/null +++ b/pkg/models/base_metrics.go @@ -0,0 +1,215 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// BaseMetrics BaseMetrics +// +// swagger:model BaseMetrics +type BaseMetrics struct { + + // feature flags (expected to be empty for remediation components) + FeatureFlags []string `json:"feature_flags"` + + // metrics details + Metrics []*DetailedMetrics `json:"metrics"` + + // os + Os *OSversion `json:"os,omitempty"` + + // UTC timestamp of the startup of the software + // Required: true + UtcStartupTimestamp *int64 `json:"utc_startup_timestamp"` + + // version of the remediation component + // Required: true + // Max Length: 255 + Version *string `json:"version"` +} + +// Validate validates this base metrics +func (m *BaseMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateMetrics(formats); err != nil { + res = append(res, err) + } + + if err := m.validateOs(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUtcStartupTimestamp(formats); err != nil { + res = append(res, err) + } + + if err := m.validateVersion(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *BaseMetrics) validateMetrics(formats strfmt.Registry) error { + if swag.IsZero(m.Metrics) { // not required + return nil + } + + for i := 0; i < len(m.Metrics); i++ { + if swag.IsZero(m.Metrics[i]) { // not required + continue + } + + if m.Metrics[i] != nil { + if err := m.Metrics[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("metrics" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("metrics" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *BaseMetrics) validateOs(formats strfmt.Registry) error { + if swag.IsZero(m.Os) { // not required + return nil + } + + if m.Os != nil { + if err := m.Os.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("os") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("os") + } + return err + } + } + + return nil +} + +func (m *BaseMetrics) validateUtcStartupTimestamp(formats strfmt.Registry) error { + + if err := validate.Required("utc_startup_timestamp", "body", m.UtcStartupTimestamp); err != nil { + return err + } + + return nil +} + +func (m *BaseMetrics) validateVersion(formats strfmt.Registry) error { + + if err := validate.Required("version", "body", m.Version); err != nil { + return err + } + + if err := validate.MaxLength("version", "body", *m.Version, 255); err != nil { + return err + } + + return nil +} + +// ContextValidate validate this base metrics based on the context it is used +func (m *BaseMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateMetrics(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateOs(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *BaseMetrics) contextValidateMetrics(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Metrics); i++ { + + if m.Metrics[i] != nil { + + if swag.IsZero(m.Metrics[i]) { // not required + return nil + } + + if err := m.Metrics[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("metrics" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("metrics" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *BaseMetrics) contextValidateOs(ctx context.Context, formats strfmt.Registry) error { + + if m.Os != nil { + + if swag.IsZero(m.Os) { // not required + return nil + } + + if err := m.Os.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("os") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("os") + } + return err + } + } + + return nil +} + +// MarshalBinary interface implementation +func (m *BaseMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *BaseMetrics) UnmarshalBinary(b []byte) error { + var res BaseMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/console_options.go b/pkg/models/console_options.go new file mode 100644 index 00000000000..87983ab1762 --- /dev/null +++ b/pkg/models/console_options.go @@ -0,0 +1,27 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/strfmt" +) + +// ConsoleOptions ConsoleOptions +// +// swagger:model ConsoleOptions +type ConsoleOptions []string + +// Validate validates this console options +func (m ConsoleOptions) Validate(formats strfmt.Registry) error { + return nil +} + +// ContextValidate validates this console options based on context it is used +func (m ConsoleOptions) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} diff --git a/pkg/models/detailed_metrics.go b/pkg/models/detailed_metrics.go new file mode 100644 index 00000000000..9e605ed8c88 --- /dev/null +++ b/pkg/models/detailed_metrics.go @@ -0,0 +1,173 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// DetailedMetrics DetailedMetrics +// +// swagger:model DetailedMetrics +type DetailedMetrics struct { + + // items + // Required: true + Items []*MetricsDetailItem `json:"items"` + + // meta + // Required: true + Meta *MetricsMeta `json:"meta"` +} + +// Validate validates this detailed metrics +func (m *DetailedMetrics) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateItems(formats); err != nil { + res = append(res, err) + } + + if err := m.validateMeta(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *DetailedMetrics) validateItems(formats strfmt.Registry) error { + + if err := validate.Required("items", "body", m.Items); err != nil { + return err + } + + for i := 0; i < len(m.Items); i++ { + if swag.IsZero(m.Items[i]) { // not required + continue + } + + if m.Items[i] != nil { + if err := m.Items[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *DetailedMetrics) validateMeta(formats strfmt.Registry) error { + + if err := validate.Required("meta", "body", m.Meta); err != nil { + return err + } + + if m.Meta != nil { + if err := m.Meta.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("meta") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("meta") + } + return err + } + } + + return nil +} + +// ContextValidate validate this detailed metrics based on the context it is used +func (m *DetailedMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateItems(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateMeta(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *DetailedMetrics) contextValidateItems(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Items); i++ { + + if m.Items[i] != nil { + + if swag.IsZero(m.Items[i]) { // not required + return nil + } + + if err := m.Items[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *DetailedMetrics) contextValidateMeta(ctx context.Context, formats strfmt.Registry) error { + + if m.Meta != nil { + + if err := m.Meta.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("meta") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("meta") + } + return err + } + } + + return nil +} + +// MarshalBinary interface implementation +func (m *DetailedMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *DetailedMetrics) UnmarshalBinary(b []byte) error { + var res DetailedMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/generate.go b/pkg/models/generate.go new file mode 100644 index 00000000000..502d6f3d2cf --- /dev/null +++ b/pkg/models/generate.go @@ -0,0 +1,4 @@ +package models + +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./localapi_swagger.yaml --target=../ + diff --git a/pkg/models/get_alerts_response.go b/pkg/models/get_alerts_response.go index 41b9d5afdbd..d4ea36e02c5 100644 --- a/pkg/models/get_alerts_response.go +++ b/pkg/models/get_alerts_response.go @@ -54,6 +54,11 @@ func (m GetAlertsResponse) ContextValidate(ctx context.Context, formats strfmt.R for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/get_decisions_response.go b/pkg/models/get_decisions_response.go index b65b950fc58..19437dc9b38 100644 --- a/pkg/models/get_decisions_response.go +++ b/pkg/models/get_decisions_response.go @@ -54,6 +54,11 @@ func (m GetDecisionsResponse) ContextValidate(ctx context.Context, formats strfm for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/helpers.go b/pkg/models/helpers.go index 8c082550d48..5bc3f2a28b3 100644 --- a/pkg/models/helpers.go +++ b/pkg/models/helpers.go @@ -1,27 +1,33 @@ package models -func (a *Alert) HasRemediation() bool { - return true -} +import ( + "fmt" + + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" +) + +const ( + // these are duplicated from pkg/types + // TODO XXX: de-duplicate + Ip = "Ip" + Range = "Range" + CscliImportOrigin = "cscli-import" +) func (a *Alert) GetScope() string { - if a.Source.Scope == nil { - return "" - } - return *a.Source.Scope + return a.Source.GetScope() } func (a *Alert) GetValue() string { - if a.Source.Value == nil { - return "" - } - return *a.Source.Value + return a.Source.GetValue() } func (a *Alert) GetScenario() string { if a.Scenario == nil { return "" } + return *a.Scenario } @@ -29,6 +35,7 @@ func (a *Alert) GetEventsCount() int32 { if a.EventsCount == nil { return 0 } + return *a.EventsCount } @@ -38,6 +45,7 @@ func (e *Event) GetMeta(key string) string { return meta.Value } } + return "" } @@ -47,6 +55,7 @@ func (a *Alert) GetMeta(key string) string { return meta.Value } } + return "" } @@ -54,6 +63,7 @@ func (s Source) GetValue() string { if s.Value == nil { return "" } + return *s.Value } @@ -61,6 +71,7 @@ func (s Source) GetScope() string { if s.Scope == nil { return "" } + return *s.Scope } @@ -69,8 +80,88 @@ func (s Source) GetAsNumberName() string { if s.AsNumber != "0" { ret += s.AsNumber } + if s.AsName != "" { ret += " " + s.AsName } + return ret } + +func (s *Source) String() string { + if s == nil || s.Scope == nil || *s.Scope == "" { + return "empty source" + } + + cn := s.Cn + + if s.AsNumber != "" { + cn += "/" + s.AsNumber + } + + if cn != "" { + cn = " (" + cn + ")" + } + + switch *s.Scope { + case Ip: + return "ip " + *s.Value + cn + case Range: + return "range " + *s.Value + cn + default: + return *s.Scope + " " + *s.Value + } +} + +func (a *Alert) FormatAsStrings(machineID string, logger *log.Logger) []string { + src := a.Source.String() + + msg := "empty scenario" + if a.Scenario != nil && *a.Scenario != "" { + msg = *a.Scenario + } else if a.Message != nil && *a.Message != "" { + msg = *a.Message + } + + reason := fmt.Sprintf("%s by %s", msg, src) + + if len(a.Decisions) == 0 { + return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} + } + + var retStr []string + + if a.Decisions[0].Origin != nil && *a.Decisions[0].Origin == CscliImportOrigin { + return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} + } + + for i, decisionItem := range a.Decisions { + decision := "" + if a.Simulated != nil && *a.Simulated { + decision = "(simulated alert)" + } else if decisionItem.Simulated != nil && *decisionItem.Simulated { + decision = "(simulated decision)" + } + + if logger.GetLevel() >= log.DebugLevel { + /*spew is expensive*/ + logger.Debug(spew.Sdump(decisionItem)) + } + + if len(a.Decisions) > 1 { + reason = fmt.Sprintf("%s for %d/%d decisions", msg, i+1, len(a.Decisions)) + } + + origin := *decisionItem.Origin + if machineID != "" { + origin = machineID + "/" + origin + } + + decision += fmt.Sprintf("%s %s on %s %s", *decisionItem.Duration, + *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) + retStr = append(retStr, + fmt.Sprintf("(%s) %s : %s", origin, reason, decision)) + } + + return retStr +} diff --git a/pkg/models/hub_item.go b/pkg/models/hub_item.go new file mode 100644 index 00000000000..c2bac3702c2 --- /dev/null +++ b/pkg/models/hub_item.go @@ -0,0 +1,56 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// HubItem HubItem +// +// swagger:model HubItem +type HubItem struct { + + // name of the hub item + Name string `json:"name,omitempty"` + + // status of the hub item (official, custom, tainted, etc.) + Status string `json:"status,omitempty"` + + // version of the hub item + Version string `json:"version,omitempty"` +} + +// Validate validates this hub item +func (m *HubItem) Validate(formats strfmt.Registry) error { + return nil +} + +// ContextValidate validates this hub item based on context it is used +func (m *HubItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *HubItem) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *HubItem) UnmarshalBinary(b []byte) error { + var res HubItem + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/hub_items.go b/pkg/models/hub_items.go new file mode 100644 index 00000000000..82388d5b97e --- /dev/null +++ b/pkg/models/hub_items.go @@ -0,0 +1,83 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// HubItems HubItems +// +// swagger:model HubItems +type HubItems map[string][]HubItem + +// Validate validates this hub items +func (m HubItems) Validate(formats strfmt.Registry) error { + var res []error + + for k := range m { + + if err := validate.Required(k, "body", m[k]); err != nil { + return err + } + + for i := 0; i < len(m[k]); i++ { + + if err := m[k][i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(k + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(k + "." + strconv.Itoa(i)) + } + return err + } + + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validate this hub items based on the context it is used +func (m HubItems) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + for k := range m { + + for i := 0; i < len(m[k]); i++ { + + if swag.IsZero(m[k][i]) { // not required + return nil + } + + if err := m[k][i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(k + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(k + "." + strconv.Itoa(i)) + } + return err + } + + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} diff --git a/pkg/models/lapi_metrics.go b/pkg/models/lapi_metrics.go new file mode 100644 index 00000000000..b56d92ef1f8 --- /dev/null +++ b/pkg/models/lapi_metrics.go @@ -0,0 +1,157 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// LapiMetrics LapiMetrics +// +// swagger:model LapiMetrics +type LapiMetrics struct { + BaseMetrics + + // console options + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *LapiMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.ConsoleOptions = dataAO1.ConsoleOptions + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m LapiMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + ConsoleOptions ConsoleOptions `json:"console_options,omitempty"` + } + + dataAO1.ConsoleOptions = m.ConsoleOptions + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this lapi metrics +func (m *LapiMetrics) Validate(formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if err := m.validateConsoleOptions(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LapiMetrics) validateConsoleOptions(formats strfmt.Registry) error { + + if swag.IsZero(m.ConsoleOptions) { // not required + return nil + } + + if err := m.ConsoleOptions.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("console_options") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("console_options") + } + return err + } + + return nil +} + +// ContextValidate validate this lapi metrics based on the context it is used +func (m *LapiMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateConsoleOptions(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LapiMetrics) contextValidateConsoleOptions(ctx context.Context, formats strfmt.Registry) error { + + if err := m.ConsoleOptions.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("console_options") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("console_options") + } + return err + } + + return nil +} + +// MarshalBinary interface implementation +func (m *LapiMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *LapiMetrics) UnmarshalBinary(b []byte) error { + var res LapiMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/localapi_swagger.yaml b/pkg/models/localapi_swagger.yaml index 66132e5e36e..01bbe6f8bde 100644 --- a/pkg/models/localapi_swagger.yaml +++ b/pkg/models/localapi_swagger.yaml @@ -26,10 +26,10 @@ produces: paths: /decisions/stream: get: - description: Returns a list of new/expired decisions. Intended for bouncers that need to "stream" decisions + description: Returns a list of new/expired decisions. Intended for remediation component that need to "stream" decisions summary: getDecisionsStream tags: - - bouncers + - Remediation component operationId: getDecisionsStream deprecated: false produces: @@ -39,7 +39,7 @@ paths: in: query required: false type: boolean - description: 'If true, means that the bouncers is starting and a full list must be provided' + description: 'If true, means that the remediation component is starting and a full list must be provided' - name: scopes in: query required: false @@ -73,10 +73,10 @@ paths: security: - APIKeyAuthorizer: [] head: - description: Returns a list of new/expired decisions. Intended for bouncers that need to "stream" decisions + description: Returns a list of new/expired decisions. Intended for remediation component that need to "stream" decisions summary: GetDecisionsStream tags: - - bouncers + - Remediation component operationId: headDecisionsStream deprecated: false produces: @@ -100,7 +100,7 @@ paths: description: Returns information about existing decisions summary: getDecisions tags: - - bouncers + - Remediation component operationId: getDecisions deprecated: false produces: @@ -160,11 +160,13 @@ paths: description: "400 response" schema: $ref: "#/definitions/ErrorResponse" + security: + - APIKeyAuthorizer: [] head: description: Returns information about existing decisions summary: GetDecisions tags: - - bouncers + - Remediation component operationId: headDecisions deprecated: false produces: @@ -310,6 +312,9 @@ paths: '201': description: Watcher Created headers: {} + '202': + description: Watcher Validated + headers: {} '400': description: "400 response" schema: @@ -684,6 +689,36 @@ paths: $ref: "#/definitions/ErrorResponse" security: - JWTAuthorizer: [] + /usage-metrics: + post: + description: Post usage metrics from a LP or a bouncer + summary: Send usage metrics + tags: + - Remediation component + - watchers + operationId: usage-metrics + produces: + - application/json + parameters: + - name: body + in: body + required: true + schema: + $ref: '#/definitions/AllMetrics' + description: 'All metrics' + responses: + '200': + description: successful operation + schema: + $ref: '#/definitions/SuccessResponse' + headers: {} + '400': + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - APIKeyAuthorizer: [] + - JWTAuthorizer: [] definitions: WatcherRegistrationRequest: title: WatcherRegistrationRequest @@ -694,6 +729,10 @@ definitions: password: type: string format: password + registration_token: + type: string + minLength: 32 + maxLength: 255 required: - machine_id - password @@ -994,6 +1033,193 @@ definitions: type: string value: type: string + RemediationComponentsMetrics: + title: RemediationComponentsMetrics + type: object + allOf: + - $ref: '#/definitions/BaseMetrics' + - properties: + type: + type: string + description: type of the remediation component + name: + type: string + description: name of the remediation component + last_pull: + type: integer + description: last pull date + LogProcessorsMetrics: + title: LogProcessorsMetrics + type: object + allOf: + - $ref: '#/definitions/BaseMetrics' + - properties: + hub_items: + $ref: '#/definitions/HubItems' + datasources: + type: object + description: Number of datasources per type + additionalProperties: + type: integer + name: + type: string + description: name of the log processor + last_push: + type: integer + description: last push date + last_update: + type: integer + description: last update date + required: + - hub_items + - datasources + LapiMetrics: + title: LapiMetrics + type: object + allOf: + - $ref: '#/definitions/BaseMetrics' + - properties: + console_options: + $ref: '#/definitions/ConsoleOptions' + AllMetrics: + title: AllMetrics + type: object + properties: + remediation_components: + type: array + items: + $ref: '#/definitions/RemediationComponentsMetrics' + description: remediation components metrics + log_processors: + type: array + items: + $ref: '#/definitions/LogProcessorsMetrics' + description: log processors metrics + lapi: + $ref: '#/definitions/LapiMetrics' + BaseMetrics: + title: BaseMetrics + type: object + properties: + version: + type: string + description: version of the remediation component + maxLength: 255 + os: + $ref: '#/definitions/OSversion' + metrics: + type: array + items: + $ref: '#/definitions/DetailedMetrics' + description: metrics details + feature_flags: + type: array + items: + type: string + description: feature flags (expected to be empty for remediation components) + maxLength: 255 + utc_startup_timestamp: + type: integer + description: UTC timestamp of the startup of the software + required: + - version + - utc_startup_timestamp + OSversion: + title: OSversion + type: object + properties: + name: + type: string + description: name of the OS + maxLength: 255 + version: + type: string + description: version of the OS + maxLength: 255 + required: + - name + - version + DetailedMetrics: + type: object + title: DetailedMetrics + properties: + items: + type: array + items: + $ref: '#/definitions/MetricsDetailItem' + meta: + $ref: '#/definitions/MetricsMeta' + required: + - meta + - items + MetricsDetailItem: + title: MetricsDetailItem + type: object + properties: + name: + type: string + description: name of the metric + maxLength: 255 + value: + type: number + description: value of the metric + unit: + type: string + description: unit of the metric + maxLength: 255 + labels: + $ref: '#/definitions/MetricsLabels' + description: labels of the metric + required: + - name + - value + - unit + MetricsMeta: + title: MetricsMeta + type: object + properties: + window_size_seconds: + type: integer + description: Size, in seconds, of the window used to compute the metric + utc_now_timestamp: + type: integer + description: UTC timestamp of the current time + required: + - window_size_seconds + - utc_now_timestamp + MetricsLabels: + title: MetricsLabels + type: object + additionalProperties: + type: string + description: label of the metric + maxLength: 255 + ConsoleOptions: + title: ConsoleOptions + type: array + items: + type: string + description: enabled console options + HubItems: + title: HubItems + type: object + additionalProperties: + type: array + items: + $ref: '#/definitions/HubItem' + HubItem: + title: HubItem + type: object + properties: + name: + type: string + description: name of the hub item + version: + type: string + description: version of the hub item + status: + type: string + description: status of the hub item (official, custom, tainted, etc.) ErrorResponse: type: "object" required: @@ -1007,8 +1233,18 @@ definitions: description: "more detail on individual errors" title: "error response" description: "error response return by the API" + SuccessResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "message" + title: "success response" + description: "success response return by the API" tags: - - name: bouncers + - name: Remediation component description: 'Operations about decisions : bans, captcha, rate-limit etc.' - name: watchers description: 'Operations about watchers : cscli & crowdsec' diff --git a/pkg/models/log_processors_metrics.go b/pkg/models/log_processors_metrics.go new file mode 100644 index 00000000000..05b688fb994 --- /dev/null +++ b/pkg/models/log_processors_metrics.go @@ -0,0 +1,219 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// LogProcessorsMetrics LogProcessorsMetrics +// +// swagger:model LogProcessorsMetrics +type LogProcessorsMetrics struct { + BaseMetrics + + // Number of datasources per type + // Required: true + Datasources map[string]int64 `json:"datasources"` + + // hub items + // Required: true + HubItems HubItems `json:"hub_items"` + + // last push date + LastPush int64 `json:"last_push,omitempty"` + + // last update date + LastUpdate int64 `json:"last_update,omitempty"` + + // name of the log processor + Name string `json:"name,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *LogProcessorsMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + Datasources map[string]int64 `json:"datasources"` + + HubItems HubItems `json:"hub_items"` + + LastPush int64 `json:"last_push,omitempty"` + + LastUpdate int64 `json:"last_update,omitempty"` + + Name string `json:"name,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.Datasources = dataAO1.Datasources + + m.HubItems = dataAO1.HubItems + + m.LastPush = dataAO1.LastPush + + m.LastUpdate = dataAO1.LastUpdate + + m.Name = dataAO1.Name + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m LogProcessorsMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + Datasources map[string]int64 `json:"datasources"` + + HubItems HubItems `json:"hub_items"` + + LastPush int64 `json:"last_push,omitempty"` + + LastUpdate int64 `json:"last_update,omitempty"` + + Name string `json:"name,omitempty"` + } + + dataAO1.Datasources = m.Datasources + + dataAO1.HubItems = m.HubItems + + dataAO1.LastPush = m.LastPush + + dataAO1.LastUpdate = m.LastUpdate + + dataAO1.Name = m.Name + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this log processors metrics +func (m *LogProcessorsMetrics) Validate(formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if err := m.validateDatasources(formats); err != nil { + res = append(res, err) + } + + if err := m.validateHubItems(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LogProcessorsMetrics) validateDatasources(formats strfmt.Registry) error { + + if err := validate.Required("datasources", "body", m.Datasources); err != nil { + return err + } + + return nil +} + +func (m *LogProcessorsMetrics) validateHubItems(formats strfmt.Registry) error { + + if err := validate.Required("hub_items", "body", m.HubItems); err != nil { + return err + } + + if m.HubItems != nil { + if err := m.HubItems.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("hub_items") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("hub_items") + } + return err + } + } + + return nil +} + +// ContextValidate validate this log processors metrics based on the context it is used +func (m *LogProcessorsMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if err := m.contextValidateHubItems(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *LogProcessorsMetrics) contextValidateHubItems(ctx context.Context, formats strfmt.Registry) error { + + if err := m.HubItems.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("hub_items") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("hub_items") + } + return err + } + + return nil +} + +// MarshalBinary interface implementation +func (m *LogProcessorsMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *LogProcessorsMetrics) UnmarshalBinary(b []byte) error { + var res LogProcessorsMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/meta.go b/pkg/models/meta.go index 6ad20856d6a..df5ae3c6285 100644 --- a/pkg/models/meta.go +++ b/pkg/models/meta.go @@ -56,6 +56,11 @@ func (m Meta) ContextValidate(ctx context.Context, formats strfmt.Registry) erro for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/models/metrics.go b/pkg/models/metrics.go index 573678d1f84..7fbb91c63e4 100644 --- a/pkg/models/metrics.go +++ b/pkg/models/metrics.go @@ -141,6 +141,11 @@ func (m *Metrics) contextValidateBouncers(ctx context.Context, formats strfmt.Re for i := 0; i < len(m.Bouncers); i++ { if m.Bouncers[i] != nil { + + if swag.IsZero(m.Bouncers[i]) { // not required + return nil + } + if err := m.Bouncers[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("bouncers" + "." + strconv.Itoa(i)) @@ -161,6 +166,11 @@ func (m *Metrics) contextValidateMachines(ctx context.Context, formats strfmt.Re for i := 0; i < len(m.Machines); i++ { if m.Machines[i] != nil { + + if swag.IsZero(m.Machines[i]) { // not required + return nil + } + if err := m.Machines[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("machines" + "." + strconv.Itoa(i)) diff --git a/pkg/models/metrics_detail_item.go b/pkg/models/metrics_detail_item.go new file mode 100644 index 00000000000..bb237884fcf --- /dev/null +++ b/pkg/models/metrics_detail_item.go @@ -0,0 +1,168 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MetricsDetailItem MetricsDetailItem +// +// swagger:model MetricsDetailItem +type MetricsDetailItem struct { + + // labels of the metric + Labels MetricsLabels `json:"labels,omitempty"` + + // name of the metric + // Required: true + // Max Length: 255 + Name *string `json:"name"` + + // unit of the metric + // Required: true + // Max Length: 255 + Unit *string `json:"unit"` + + // value of the metric + // Required: true + Value *float64 `json:"value"` +} + +// Validate validates this metrics detail item +func (m *MetricsDetailItem) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateLabels(formats); err != nil { + res = append(res, err) + } + + if err := m.validateName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUnit(formats); err != nil { + res = append(res, err) + } + + if err := m.validateValue(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MetricsDetailItem) validateLabels(formats strfmt.Registry) error { + if swag.IsZero(m.Labels) { // not required + return nil + } + + if m.Labels != nil { + if err := m.Labels.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("labels") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("labels") + } + return err + } + } + + return nil +} + +func (m *MetricsDetailItem) validateName(formats strfmt.Registry) error { + + if err := validate.Required("name", "body", m.Name); err != nil { + return err + } + + if err := validate.MaxLength("name", "body", *m.Name, 255); err != nil { + return err + } + + return nil +} + +func (m *MetricsDetailItem) validateUnit(formats strfmt.Registry) error { + + if err := validate.Required("unit", "body", m.Unit); err != nil { + return err + } + + if err := validate.MaxLength("unit", "body", *m.Unit, 255); err != nil { + return err + } + + return nil +} + +func (m *MetricsDetailItem) validateValue(formats strfmt.Registry) error { + + if err := validate.Required("value", "body", m.Value); err != nil { + return err + } + + return nil +} + +// ContextValidate validate this metrics detail item based on the context it is used +func (m *MetricsDetailItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateLabels(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MetricsDetailItem) contextValidateLabels(ctx context.Context, formats strfmt.Registry) error { + + if swag.IsZero(m.Labels) { // not required + return nil + } + + if err := m.Labels.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("labels") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("labels") + } + return err + } + + return nil +} + +// MarshalBinary interface implementation +func (m *MetricsDetailItem) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MetricsDetailItem) UnmarshalBinary(b []byte) error { + var res MetricsDetailItem + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/metrics_labels.go b/pkg/models/metrics_labels.go new file mode 100644 index 00000000000..176a15cce24 --- /dev/null +++ b/pkg/models/metrics_labels.go @@ -0,0 +1,42 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/validate" +) + +// MetricsLabels MetricsLabels +// +// swagger:model MetricsLabels +type MetricsLabels map[string]string + +// Validate validates this metrics labels +func (m MetricsLabels) Validate(formats strfmt.Registry) error { + var res []error + + for k := range m { + + if err := validate.MaxLength(k, "body", m[k], 255); err != nil { + return err + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validates this metrics labels based on context it is used +func (m MetricsLabels) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} diff --git a/pkg/models/metrics_meta.go b/pkg/models/metrics_meta.go new file mode 100644 index 00000000000..b021617e4d9 --- /dev/null +++ b/pkg/models/metrics_meta.go @@ -0,0 +1,88 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MetricsMeta MetricsMeta +// +// swagger:model MetricsMeta +type MetricsMeta struct { + + // UTC timestamp of the current time + // Required: true + UtcNowTimestamp *int64 `json:"utc_now_timestamp"` + + // Size, in seconds, of the window used to compute the metric + // Required: true + WindowSizeSeconds *int64 `json:"window_size_seconds"` +} + +// Validate validates this metrics meta +func (m *MetricsMeta) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateUtcNowTimestamp(formats); err != nil { + res = append(res, err) + } + + if err := m.validateWindowSizeSeconds(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MetricsMeta) validateUtcNowTimestamp(formats strfmt.Registry) error { + + if err := validate.Required("utc_now_timestamp", "body", m.UtcNowTimestamp); err != nil { + return err + } + + return nil +} + +func (m *MetricsMeta) validateWindowSizeSeconds(formats strfmt.Registry) error { + + if err := validate.Required("window_size_seconds", "body", m.WindowSizeSeconds); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this metrics meta based on context it is used +func (m *MetricsMeta) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *MetricsMeta) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MetricsMeta) UnmarshalBinary(b []byte) error { + var res MetricsMeta + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/o_sversion.go b/pkg/models/o_sversion.go new file mode 100644 index 00000000000..8f1f43ea9cc --- /dev/null +++ b/pkg/models/o_sversion.go @@ -0,0 +1,98 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// OSversion OSversion +// +// swagger:model OSversion +type OSversion struct { + + // name of the OS + // Required: true + // Max Length: 255 + Name *string `json:"name"` + + // version of the OS + // Required: true + // Max Length: 255 + Version *string `json:"version"` +} + +// Validate validates this o sversion +func (m *OSversion) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateVersion(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *OSversion) validateName(formats strfmt.Registry) error { + + if err := validate.Required("name", "body", m.Name); err != nil { + return err + } + + if err := validate.MaxLength("name", "body", *m.Name, 255); err != nil { + return err + } + + return nil +} + +func (m *OSversion) validateVersion(formats strfmt.Registry) error { + + if err := validate.Required("version", "body", m.Version); err != nil { + return err + } + + if err := validate.MaxLength("version", "body", *m.Version, 255); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this o sversion based on context it is used +func (m *OSversion) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *OSversion) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *OSversion) UnmarshalBinary(b []byte) error { + var res OSversion + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/remediation_components_metrics.go b/pkg/models/remediation_components_metrics.go new file mode 100644 index 00000000000..ba3845d872a --- /dev/null +++ b/pkg/models/remediation_components_metrics.go @@ -0,0 +1,139 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// RemediationComponentsMetrics RemediationComponentsMetrics +// +// swagger:model RemediationComponentsMetrics +type RemediationComponentsMetrics struct { + BaseMetrics + + // last pull date + LastPull int64 `json:"last_pull,omitempty"` + + // name of the remediation component + Name string `json:"name,omitempty"` + + // type of the remediation component + Type string `json:"type,omitempty"` +} + +// UnmarshalJSON unmarshals this object from a JSON structure +func (m *RemediationComponentsMetrics) UnmarshalJSON(raw []byte) error { + // AO0 + var aO0 BaseMetrics + if err := swag.ReadJSON(raw, &aO0); err != nil { + return err + } + m.BaseMetrics = aO0 + + // AO1 + var dataAO1 struct { + LastPull int64 `json:"last_pull,omitempty"` + + Name string `json:"name,omitempty"` + + Type string `json:"type,omitempty"` + } + if err := swag.ReadJSON(raw, &dataAO1); err != nil { + return err + } + + m.LastPull = dataAO1.LastPull + + m.Name = dataAO1.Name + + m.Type = dataAO1.Type + + return nil +} + +// MarshalJSON marshals this object to a JSON structure +func (m RemediationComponentsMetrics) MarshalJSON() ([]byte, error) { + _parts := make([][]byte, 0, 2) + + aO0, err := swag.WriteJSON(m.BaseMetrics) + if err != nil { + return nil, err + } + _parts = append(_parts, aO0) + var dataAO1 struct { + LastPull int64 `json:"last_pull,omitempty"` + + Name string `json:"name,omitempty"` + + Type string `json:"type,omitempty"` + } + + dataAO1.LastPull = m.LastPull + + dataAO1.Name = m.Name + + dataAO1.Type = m.Type + + jsonDataAO1, errAO1 := swag.WriteJSON(dataAO1) + if errAO1 != nil { + return nil, errAO1 + } + _parts = append(_parts, jsonDataAO1) + return swag.ConcatJSON(_parts...), nil +} + +// Validate validates this remediation components metrics +func (m *RemediationComponentsMetrics) Validate(formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.Validate(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validate this remediation components metrics based on the context it is used +func (m *RemediationComponentsMetrics) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + // validation for a type composition with BaseMetrics + if err := m.BaseMetrics.ContextValidate(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// MarshalBinary interface implementation +func (m *RemediationComponentsMetrics) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *RemediationComponentsMetrics) UnmarshalBinary(b []byte) error { + var res RemediationComponentsMetrics + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/success_response.go b/pkg/models/success_response.go new file mode 100644 index 00000000000..e8fc281c090 --- /dev/null +++ b/pkg/models/success_response.go @@ -0,0 +1,73 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// SuccessResponse success response +// +// success response return by the API +// +// swagger:model SuccessResponse +type SuccessResponse struct { + + // message + // Required: true + Message *string `json:"message"` +} + +// Validate validates this success response +func (m *SuccessResponse) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateMessage(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *SuccessResponse) validateMessage(formats strfmt.Registry) error { + + if err := validate.Required("message", "body", m.Message); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this success response based on context it is used +func (m *SuccessResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *SuccessResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *SuccessResponse) UnmarshalBinary(b []byte) error { + var res SuccessResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/watcher_registration_request.go b/pkg/models/watcher_registration_request.go index 8be802ea3e7..673f0d59b9e 100644 --- a/pkg/models/watcher_registration_request.go +++ b/pkg/models/watcher_registration_request.go @@ -27,6 +27,11 @@ type WatcherRegistrationRequest struct { // Required: true // Format: password Password *strfmt.Password `json:"password"` + + // registration token + // Max Length: 255 + // Min Length: 32 + RegistrationToken string `json:"registration_token,omitempty"` } // Validate validates this watcher registration request @@ -41,6 +46,10 @@ func (m *WatcherRegistrationRequest) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateRegistrationToken(formats); err != nil { + res = append(res, err) + } + if len(res) > 0 { return errors.CompositeValidationError(res...) } @@ -69,6 +78,22 @@ func (m *WatcherRegistrationRequest) validatePassword(formats strfmt.Registry) e return nil } +func (m *WatcherRegistrationRequest) validateRegistrationToken(formats strfmt.Registry) error { + if swag.IsZero(m.RegistrationToken) { // not required + return nil + } + + if err := validate.MinLength("registration_token", "body", m.RegistrationToken, 32); err != nil { + return err + } + + if err := validate.MaxLength("registration_token", "body", m.RegistrationToken, 255); err != nil { + return err + } + + return nil +} + // ContextValidate validates this watcher registration request based on context it is used func (m *WatcherRegistrationRequest) ContextValidate(ctx context.Context, formats strfmt.Registry) error { return nil diff --git a/pkg/modelscapi/add_signals_request.go b/pkg/modelscapi/add_signals_request.go index 62fe590cb79..7bfe6ae80e0 100644 --- a/pkg/modelscapi/add_signals_request.go +++ b/pkg/modelscapi/add_signals_request.go @@ -56,6 +56,11 @@ func (m AddSignalsRequest) ContextValidate(ctx context.Context, formats strfmt.R for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item.go b/pkg/modelscapi/add_signals_request_item.go index f9c865b4c68..5f63b542d5a 100644 --- a/pkg/modelscapi/add_signals_request_item.go +++ b/pkg/modelscapi/add_signals_request_item.go @@ -65,6 +65,9 @@ type AddSignalsRequestItem struct { // stop at // Required: true StopAt *string `json:"stop_at"` + + // UUID of the alert + UUID string `json:"uuid,omitempty"` } // Validate validates this add signals request item @@ -257,6 +260,11 @@ func (m *AddSignalsRequestItem) contextValidateContext(ctx context.Context, form for i := 0; i < len(m.Context); i++ { if m.Context[i] != nil { + + if swag.IsZero(m.Context[i]) { // not required + return nil + } + if err := m.Context[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("context" + "." + strconv.Itoa(i)) @@ -289,6 +297,7 @@ func (m *AddSignalsRequestItem) contextValidateDecisions(ctx context.Context, fo func (m *AddSignalsRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/add_signals_request_item_decisions.go b/pkg/modelscapi/add_signals_request_item_decisions.go index 54e123ab3f8..11ed27a496d 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions.go +++ b/pkg/modelscapi/add_signals_request_item_decisions.go @@ -54,6 +54,11 @@ func (m AddSignalsRequestItemDecisions) ContextValidate(ctx context.Context, for for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item_decisions_item.go b/pkg/modelscapi/add_signals_request_item_decisions_item.go index 34dfeb5bce5..797c517e33f 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions_item.go +++ b/pkg/modelscapi/add_signals_request_item_decisions_item.go @@ -49,6 +49,9 @@ type AddSignalsRequestItemDecisionsItem struct { // until Until string `json:"until,omitempty"` + // UUID of the decision + UUID string `json:"uuid,omitempty"` + // the value of the decision scope : an IP, a range, a username, etc // Required: true Value *string `json:"value"` diff --git a/pkg/modelscapi/centralapi_swagger.yaml b/pkg/modelscapi/centralapi_swagger.yaml new file mode 100644 index 00000000000..bd695894f2b --- /dev/null +++ b/pkg/modelscapi/centralapi_swagger.yaml @@ -0,0 +1,875 @@ +swagger: "2.0" +info: + description: + "API to manage machines using [crowdsec](https://github.com/crowdsecurity/crowdsec)\ + \ and bouncers.\n" + version: "2023-01-23T11:16:39Z" + title: "prod-capi-v3" + contact: + name: "Crowdsec team" + url: "https://github.com/crowdsecurity/crowdsec" + email: "support@crowdsec.net" +host: "api.crowdsec.net" +basePath: "/v3" +tags: + - name: "watchers" + description: "Operations about watchers: crowdsec & cscli" + - name: "bouncers" + description: "Operations about decisions : bans, captcha, rate-limit etc." +schemes: + - "https" +paths: + /decisions/delete: + post: + tags: + - "watchers" + summary: "delete decisions" + description: "delete provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsDeleteRequest" + required: true + schema: + $ref: "#/definitions/DecisionsDeleteRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /decisions/stream: + get: + tags: + - "bouncers" + - "watchers" + summary: "returns list of top decisions" + description: "returns list of top decisions to add or delete" + produces: + - "application/json" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/GetDecisionsStreamResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" + /decisions/sync: + post: + tags: + - "watchers" + summary: "sync decisions" + description: "sync provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsSyncRequest" + required: true + schema: + $ref: "#/definitions/DecisionsSyncRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /metrics: + post: + tags: + - "watchers" + summary: "receive metrics about enrolled machines and bouncers in APIL" + description: "receive metrics about enrolled machines and bouncers in APIL" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "MetricsRequest" + required: true + schema: + $ref: "#/definitions/MetricsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /signals: + post: + tags: + - "watchers" + summary: "Push signals" + description: "to push signals" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "AddSignalsRequest" + required: true + schema: + $ref: "#/definitions/AddSignalsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers: + post: + tags: + - "watchers" + summary: "Register watcher" + description: "Register a watcher" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "RegisterRequest" + required: true + schema: + $ref: "#/definitions/RegisterRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/enroll: + post: + tags: + - "watchers" + summary: "watcher enrollment" + description: "watcher enrollment : enroll watcher to crowdsec backoffice account" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "EnrollRequest" + required: true + schema: + $ref: "#/definitions/EnrollRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers/login: + post: + tags: + - "watchers" + summary: "watcher login" + description: "Sign-in to get a valid token" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "LoginRequest" + required: true + schema: + $ref: "#/definitions/LoginRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/LoginResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/reset: + post: + tags: + - "watchers" + summary: "Reset Password" + description: "to reset a watcher password" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "ResetPasswordRequest" + required: true + schema: + $ref: "#/definitions/ResetPasswordRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" +securityDefinitions: + UserPoolAuthorizer: + type: "apiKey" + name: "Authorization" + in: "header" + x-amazon-apigateway-authtype: "cognito_user_pools" +definitions: + DecisionsDeleteRequest: + title: "delete decisions" + type: "array" + description: "delete decision model" + items: + $ref: "#/definitions/DecisionsDeleteRequestItem" + DecisionsSyncRequestItem: + type: "object" + required: + - "message" + - "scenario" + - "scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + created_at: + type: "string" + machine_id: + type: "string" + decisions: + $ref: "#/definitions/DecisionsSyncRequestItemDecisions" + source: + $ref: "#/definitions/DecisionsSyncRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + start_at: + type: "string" + stop_at: + type: "string" + title: "Signal" + AddSignalsRequestItem: + type: "object" + required: + - "message" + - "scenario" + - "scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + created_at: + type: "string" + machine_id: + type: "string" + source: + $ref: "#/definitions/AddSignalsRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + uuid: + type: "string" + description: "UUID of the alert" + start_at: + type: "string" + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + context: + type: "array" + items: + type: "object" + properties: + value: + type: "string" + key: + type: "string" + decisions: + $ref: "#/definitions/AddSignalsRequestItemDecisions" + stop_at: + type: "string" + title: "Signal" + DecisionsSyncRequest: + title: "sync decisions request" + type: "array" + description: "sync decision model" + items: + $ref: "#/definitions/DecisionsSyncRequestItem" + LoginRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + scenarios: + type: "array" + description: "all scenarios installed" + items: + type: "string" + title: "login request" + description: "Login request model" + GetDecisionsStreamResponseNewItem: + type: "object" + required: + - "scenario" + - "scope" + - "decisions" + properties: + scenario: + type: "string" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: object + required: + - value + - duration + properties: + duration: + type: "string" + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "New Decisions" + GetDecisionsStreamResponseDeletedItem: + type: object + required: + - scope + - decisions + properties: + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: string + BlocklistLink: + type: object + required: + - name + - url + - remediation + - scope + - duration + properties: + name: + type: string + description: "the name of the blocklist" + url: + type: string + description: "the url from which the blocklist content can be downloaded" + remediation: + type: string + description: "the remediation that should be used for the blocklist" + scope: + type: string + description: "the scope of decisions in the blocklist" + duration: + type: string + AddSignalsRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - "origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + uuid: + type: "string" + description: "UUID of the decision" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + EnrollRequest: + type: "object" + required: + - "attachment_key" + properties: + name: + type: "string" + description: "The name that will be display in the console for the instance" + overwrite: + type: "boolean" + description: "To force enroll the instance" + attachment_key: + type: "string" + description: + "attachment_key is generated in your crowdsec backoffice account\ + \ and allows you to enroll your machines to your BO account" + pattern: "^[a-zA-Z0-9]+$" + tags: + type: "array" + description: "Tags to apply on the console for the instance" + items: + type: "string" + title: "enroll request" + description: "enroll request model" + ResetPasswordRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + title: "resetPassword" + description: "ResetPassword request model" + MetricsRequestBouncersItem: + type: "object" + properties: + last_pull: + type: "string" + description: "last bouncer pull date" + custom_name: + type: "string" + description: "bouncer name" + name: + type: "string" + description: "bouncer type (firewall, php...)" + version: + type: "string" + description: "bouncer version" + title: "MetricsBouncerInfo" + AddSignalsRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + DecisionsSyncRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/DecisionsSyncRequestItemDecisionsItem" + RegisterRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + pattern: "^[a-zA-Z0-9]+$" + title: "register request" + description: "Register request model" + SuccessResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "message" + title: "success response" + description: "success response return by the API" + LoginResponse: + type: "object" + properties: + code: + type: "integer" + expire: + type: "string" + token: + type: "string" + title: "login response" + description: "Login request model" + DecisionsSyncRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - "origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + GetDecisionsStreamResponse: + type: "object" + properties: + new: + $ref: "#/definitions/GetDecisionsStreamResponseNew" + deleted: + $ref: "#/definitions/GetDecisionsStreamResponseDeleted" + links: + $ref: "#/definitions/GetDecisionsStreamResponseLinks" + title: "get decisions stream response" + description: "get decision response model" + DecisionsSyncRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + AddSignalsRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/AddSignalsRequestItemDecisionsItem" + MetricsRequestMachinesItem: + type: "object" + properties: + last_update: + type: "string" + description: "last agent update date" + name: + type: "string" + description: "agent name" + last_push: + type: "string" + description: "last agent push date" + version: + type: "string" + description: "agent version" + title: "MetricsAgentInfo" + MetricsRequest: + type: "object" + required: + - "bouncers" + - "machines" + properties: + bouncers: + type: "array" + items: + $ref: "#/definitions/MetricsRequestBouncersItem" + machines: + type: "array" + items: + $ref: "#/definitions/MetricsRequestMachinesItem" + title: "metrics" + description: "push metrics model" + ErrorResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "Error message" + errors: + type: "string" + description: "more detail on individual errors" + title: "error response" + description: "error response return by the API" + AddSignalsRequest: + title: "add signals request" + type: "array" + description: "All signals request model" + items: + $ref: "#/definitions/AddSignalsRequestItem" + DecisionsDeleteRequestItem: + type: "string" + title: "decisionsIDs" + GetDecisionsStreamResponseNew: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseNewItem" + GetDecisionsStreamResponseDeleted: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseDeletedItem" + GetDecisionsStreamResponseLinks: + title: "Decisions list" + type: "object" + properties: + blocklists: + type: array + items: + $ref: "#/definitions/BlocklistLink" + diff --git a/pkg/modelscapi/decisions_delete_request.go b/pkg/modelscapi/decisions_delete_request.go index e8718835027..0c93558adf1 100644 --- a/pkg/modelscapi/decisions_delete_request.go +++ b/pkg/modelscapi/decisions_delete_request.go @@ -11,6 +11,7 @@ import ( "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // DecisionsDeleteRequest delete decisions @@ -49,6 +50,10 @@ func (m DecisionsDeleteRequest) ContextValidate(ctx context.Context, formats str for i := 0; i < len(m); i++ { + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request.go b/pkg/modelscapi/decisions_sync_request.go index e3a95162519..c087d39ff62 100644 --- a/pkg/modelscapi/decisions_sync_request.go +++ b/pkg/modelscapi/decisions_sync_request.go @@ -56,6 +56,11 @@ func (m DecisionsSyncRequest) ContextValidate(ctx context.Context, formats strfm for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request_item.go b/pkg/modelscapi/decisions_sync_request_item.go index 5139ea2de4b..460fe4d430e 100644 --- a/pkg/modelscapi/decisions_sync_request_item.go +++ b/pkg/modelscapi/decisions_sync_request_item.go @@ -231,6 +231,7 @@ func (m *DecisionsSyncRequestItem) contextValidateDecisions(ctx context.Context, func (m *DecisionsSyncRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/decisions_sync_request_item_decisions.go b/pkg/modelscapi/decisions_sync_request_item_decisions.go index 76316e43c5e..bdc8e77e2b6 100644 --- a/pkg/modelscapi/decisions_sync_request_item_decisions.go +++ b/pkg/modelscapi/decisions_sync_request_item_decisions.go @@ -54,6 +54,11 @@ func (m DecisionsSyncRequestItemDecisions) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/generate.go b/pkg/modelscapi/generate.go new file mode 100644 index 00000000000..66dc2a34b7e --- /dev/null +++ b/pkg/modelscapi/generate.go @@ -0,0 +1,4 @@ +package modelscapi + +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./centralapi_swagger.yaml --target=../ --model-package=modelscapi + diff --git a/pkg/modelscapi/get_decisions_stream_response.go b/pkg/modelscapi/get_decisions_stream_response.go index af19b85c4d3..5ebf29c5d93 100644 --- a/pkg/modelscapi/get_decisions_stream_response.go +++ b/pkg/modelscapi/get_decisions_stream_response.go @@ -144,6 +144,11 @@ func (m *GetDecisionsStreamResponse) contextValidateDeleted(ctx context.Context, func (m *GetDecisionsStreamResponse) contextValidateLinks(ctx context.Context, formats strfmt.Registry) error { if m.Links != nil { + + if swag.IsZero(m.Links) { // not required + return nil + } + if err := m.Links.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("links") diff --git a/pkg/modelscapi/get_decisions_stream_response_deleted.go b/pkg/modelscapi/get_decisions_stream_response_deleted.go index d218bf87e4e..78292860f22 100644 --- a/pkg/modelscapi/get_decisions_stream_response_deleted.go +++ b/pkg/modelscapi/get_decisions_stream_response_deleted.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseDeleted) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_links.go b/pkg/modelscapi/get_decisions_stream_response_links.go index 85cc9af9b48..6b9054574f1 100644 --- a/pkg/modelscapi/get_decisions_stream_response_links.go +++ b/pkg/modelscapi/get_decisions_stream_response_links.go @@ -82,6 +82,11 @@ func (m *GetDecisionsStreamResponseLinks) contextValidateBlocklists(ctx context. for i := 0; i < len(m.Blocklists); i++ { if m.Blocklists[i] != nil { + + if swag.IsZero(m.Blocklists[i]) { // not required + return nil + } + if err := m.Blocklists[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("blocklists" + "." + strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new.go b/pkg/modelscapi/get_decisions_stream_response_new.go index e9525bf6fa7..8e09f1b20e7 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new.go +++ b/pkg/modelscapi/get_decisions_stream_response_new.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseNew) ContextValidate(ctx context.Context, form for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new_item.go b/pkg/modelscapi/get_decisions_stream_response_new_item.go index a3592d0ab61..77cc06732ce 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new_item.go +++ b/pkg/modelscapi/get_decisions_stream_response_new_item.go @@ -119,6 +119,11 @@ func (m *GetDecisionsStreamResponseNewItem) contextValidateDecisions(ctx context for i := 0; i < len(m.Decisions); i++ { if m.Decisions[i] != nil { + + if swag.IsZero(m.Decisions[i]) { // not required + return nil + } + if err := m.Decisions[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("decisions" + "." + strconv.Itoa(i)) diff --git a/pkg/modelscapi/metrics_request.go b/pkg/modelscapi/metrics_request.go index d5b7d058fc1..5d663cf1750 100644 --- a/pkg/modelscapi/metrics_request.go +++ b/pkg/modelscapi/metrics_request.go @@ -126,6 +126,11 @@ func (m *MetricsRequest) contextValidateBouncers(ctx context.Context, formats st for i := 0; i < len(m.Bouncers); i++ { if m.Bouncers[i] != nil { + + if swag.IsZero(m.Bouncers[i]) { // not required + return nil + } + if err := m.Bouncers[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("bouncers" + "." + strconv.Itoa(i)) @@ -146,6 +151,11 @@ func (m *MetricsRequest) contextValidateMachines(ctx context.Context, formats st for i := 0; i < len(m.Machines); i++ { if m.Machines[i] != nil { + + if swag.IsZero(m.Machines[i]) { // not required + return nil + } + if err := m.Machines[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("machines" + "." + strconv.Itoa(i)) diff --git a/pkg/parser/README.md b/pkg/parser/README.md index 62a56e61820..0fcccc811e4 100644 --- a/pkg/parser/README.md +++ b/pkg/parser/README.md @@ -45,7 +45,7 @@ statics: > `filter: "Line.Src endsWith '/foobar'"` - - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) that will be evaluated against the runtime of a line (`Event`) + - *optional* `filter` : an [expression](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) that will be evaluated against the runtime of a line (`Event`) - if the `filter` is present and returns false, node is not evaluated - if `filter` is absent or present and returns true, node is evaluated diff --git a/pkg/parser/enrich.go b/pkg/parser/enrich.go index 5180b9a5fb9..661410d20d3 100644 --- a/pkg/parser/enrich.go +++ b/pkg/parser/enrich.go @@ -7,7 +7,7 @@ import ( ) /* should be part of a package shared with enrich/geoip.go */ -type EnrichFunc func(string, *types.Event, interface{}, *log.Entry) (map[string]string, error) +type EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error) type InitFunc func(map[string]string) (interface{}, error) type EnricherCtx struct { @@ -16,59 +16,42 @@ type EnricherCtx struct { type Enricher struct { Name string - InitFunc InitFunc EnrichFunc EnrichFunc - Ctx interface{} } /* mimic plugin loading */ -func Loadplugin(path string) (EnricherCtx, error) { +func Loadplugin() (EnricherCtx, error) { enricherCtx := EnricherCtx{} enricherCtx.Registered = make(map[string]*Enricher) - enricherConfig := map[string]string{"datadir": path} - EnrichersList := []*Enricher{ { Name: "GeoIpCity", - InitFunc: GeoIPCityInit, EnrichFunc: GeoIpCity, }, { Name: "GeoIpASN", - InitFunc: GeoIPASNInit, EnrichFunc: GeoIpASN, }, { Name: "IpToRange", - InitFunc: IpToRangeInit, EnrichFunc: IpToRange, }, { Name: "reverse_dns", - InitFunc: reverseDNSInit, EnrichFunc: reverse_dns, }, { Name: "ParseDate", - InitFunc: parseDateInit, EnrichFunc: ParseDate, }, { Name: "UnmarshalJSON", - InitFunc: unmarshalInit, EnrichFunc: unmarshalJSON, }, } for _, enricher := range EnrichersList { - log.Debugf("Initiating enricher '%s'", enricher.Name) - pluginCtx, err := enricher.InitFunc(enricherConfig) - if err != nil { - log.Errorf("unable to register plugin '%s': %v", enricher.Name, err) - continue - } - enricher.Ctx = pluginCtx log.Infof("Successfully registered enricher '%s'", enricher.Name) enricherCtx.Registered[enricher.Name] = enricher } diff --git a/pkg/parser/enrich_date.go b/pkg/parser/enrich_date.go index 20828af9037..40c8de39da5 100644 --- a/pkg/parser/enrich_date.go +++ b/pkg/parser/enrich_date.go @@ -18,7 +18,7 @@ func parseDateWithFormat(date, format string) (string, time.Time) { } retstr, err := t.MarshalText() if err != nil { - log.Warningf("Failed marshaling '%v'", t) + log.Warningf("Failed to serialize '%v'", t) return "", time.Time{} } return string(retstr), t @@ -56,7 +56,7 @@ func GenDateParse(date string) (string, time.Time) { return "", time.Time{} } -func ParseDate(in string, p *types.Event, x interface{}, plog *log.Entry) (map[string]string, error) { +func ParseDate(in string, p *types.Event, plog *log.Entry) (map[string]string, error) { var ret = make(map[string]string) var strDate string @@ -98,14 +98,10 @@ func ParseDate(in string, p *types.Event, x interface{}, plog *log.Entry) (map[s now := time.Now().UTC() retstr, err := now.MarshalText() if err != nil { - plog.Warning("Failed marshaling current time") + plog.Warning("Failed to serialize current time") return ret, err } ret["MarshaledTime"] = string(retstr) return ret, nil } - -func parseDateInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/enrich_date_test.go b/pkg/parser/enrich_date_test.go index 084ded52573..930633feb35 100644 --- a/pkg/parser/enrich_date_test.go +++ b/pkg/parser/enrich_date_test.go @@ -42,13 +42,10 @@ func TestDateParse(t *testing.T) { }, } - logger := log.WithFields(log.Fields{ - "test": "test", - }) + logger := log.WithField("test", "test") for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { - strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, nil, logger) + strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, logger) cstest.RequireErrorContains(t, err, tt.expectedErr) if tt.expectedErr != "" { return diff --git a/pkg/parser/enrich_dns.go b/pkg/parser/enrich_dns.go index f622e6c359a..1ff5b0f4f16 100644 --- a/pkg/parser/enrich_dns.go +++ b/pkg/parser/enrich_dns.go @@ -11,7 +11,7 @@ import ( /* All plugins must export a list of function pointers for exported symbols */ //var ExportedFuncs = []string{"reverse_dns"} -func reverse_dns(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { +func reverse_dns(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { ret := make(map[string]string) if field == "" { return nil, nil @@ -25,7 +25,3 @@ func reverse_dns(field string, p *types.Event, ctx interface{}, plog *log.Entry) ret["reverse_dns"] = rets[0] return ret, nil } - -func reverseDNSInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/enrich_geoip.go b/pkg/parser/enrich_geoip.go index 0a263c82793..1756927bc4b 100644 --- a/pkg/parser/enrich_geoip.go +++ b/pkg/parser/enrich_geoip.go @@ -6,53 +6,66 @@ import ( "strconv" "github.com/oschwald/geoip2-golang" - "github.com/oschwald/maxminddb-golang" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func IpToRange(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - var dummy interface{} - ret := make(map[string]string) - +func IpToRange(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no range enrich", field) - return nil, nil - } - net, ok, err := ctx.(*maxminddb.Reader).LookupNetwork(ip, &dummy) + + r, err := exprhelpers.GeoIPRangeEnrich(field) + if err != nil { - plog.Errorf("Failed to fetch network for %s : %v", ip.String(), err) + plog.Errorf("Unable to enrich ip '%s'", field) + return nil, nil //nolint:nilerr + } + + if r == nil { + plog.Debugf("No range found for ip '%s'", field) return nil, nil } + + record, ok := r.(*net.IPNet) + if !ok { - plog.Debugf("Unable to find range of %s", ip.String()) return nil, nil } - ret["SourceRange"] = net.String() + + ret := make(map[string]string) + ret["SourceRange"] = record.String() + return ret, nil } -func GeoIpASN(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - ret := make(map[string]string) +func GeoIpASN(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no ASN enrich", ip) - return nil, nil - } - record, err := ctx.(*geoip2.Reader).ASN(ip) + r, err := exprhelpers.GeoIPASNEnrich(field) + if err != nil { - plog.Errorf("Unable to enrich ip '%s'", field) + plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr } + + if r == nil { + plog.Debugf("No ASN found for ip '%s'", field) + return nil, nil + } + + record, ok := r.(*geoip2.ASN) + + if !ok { + return nil, nil + } + + ret := make(map[string]string) + ret["ASNNumber"] = fmt.Sprintf("%d", record.AutonomousSystemNumber) ret["ASNumber"] = fmt.Sprintf("%d", record.AutonomousSystemNumber) ret["ASNOrg"] = record.AutonomousSystemOrganization @@ -62,21 +75,31 @@ func GeoIpASN(field string, p *types.Event, ctx interface{}, plog *log.Entry) (m return ret, nil } -func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { - ret := make(map[string]string) +func GeoIpCity(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { if field == "" { return nil, nil } - ip := net.ParseIP(field) - if ip == nil { - plog.Infof("Can't parse ip %s, no City enrich", ip) - return nil, nil - } - record, err := ctx.(*geoip2.Reader).City(ip) + + r, err := exprhelpers.GeoIPEnrich(field) + if err != nil { - plog.Debugf("Unable to enrich ip '%s'", ip) + plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr } + + if r == nil { + plog.Debugf("No city found for ip '%s'", field) + return nil, nil + } + + record, ok := r.(*geoip2.City) + + if !ok { + return nil, nil + } + + ret := make(map[string]string) + if record.Country.IsoCode != "" { ret["IsoCode"] = record.Country.IsoCode ret["IsInEU"] = strconv.FormatBool(record.Country.IsInEuropeanUnion) @@ -88,7 +111,7 @@ func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) ( ret["IsInEU"] = strconv.FormatBool(record.RepresentedCountry.IsInEuropeanUnion) } else { ret["IsoCode"] = "" - ret["IsInEU"] = strconv.FormatBool(false) + ret["IsInEU"] = "false" } ret["Latitude"] = fmt.Sprintf("%f", record.Location.Latitude) @@ -98,33 +121,3 @@ func GeoIpCity(field string, p *types.Event, ctx interface{}, plog *log.Entry) ( return ret, nil } - -func GeoIPCityInit(cfg map[string]string) (interface{}, error) { - dbCityReader, err := geoip2.Open(cfg["datadir"] + "/GeoLite2-City.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return dbCityReader, nil -} - -func GeoIPASNInit(cfg map[string]string) (interface{}, error) { - dbASReader, err := geoip2.Open(cfg["datadir"] + "/GeoLite2-ASN.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return dbASReader, nil -} - -func IpToRangeInit(cfg map[string]string) (interface{}, error) { - ipToRangeReader, err := maxminddb.Open(cfg["datadir"] + "/GeoLite2-ASN.mmdb") - if err != nil { - log.Debugf("couldn't open geoip : %v", err) - return nil, err - } - - return ipToRangeReader, nil -} diff --git a/pkg/parser/enrich_unmarshal.go b/pkg/parser/enrich_unmarshal.go index dce9c75d466..dbdd9d3f583 100644 --- a/pkg/parser/enrich_unmarshal.go +++ b/pkg/parser/enrich_unmarshal.go @@ -8,16 +8,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func unmarshalJSON(field string, p *types.Event, ctx interface{}, plog *log.Entry) (map[string]string, error) { +func unmarshalJSON(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { err := json.Unmarshal([]byte(p.Line.Raw), &p.Unmarshaled) if err != nil { - plog.Errorf("could not unmarshal JSON: %s", err) + plog.Errorf("could not parse JSON: %s", err) return nil, err } plog.Tracef("unmarshaled JSON: %+v", p.Unmarshaled) return nil, nil } - -func unmarshalInit(cfg map[string]string) (interface{}, error) { - return nil, nil -} diff --git a/pkg/parser/grok_pattern.go b/pkg/parser/grok_pattern.go index 5b3204a4201..9c781d47aa6 100644 --- a/pkg/parser/grok_pattern.go +++ b/pkg/parser/grok_pattern.go @@ -3,7 +3,7 @@ package parser import ( "time" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr/vm" "github.com/crowdsecurity/grokky" ) diff --git a/pkg/parser/node.go b/pkg/parser/node.go index 23ed20511c3..26046ae4fd6 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -6,9 +6,9 @@ import ( "strings" "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" "github.com/davecgh/go-spew/spew" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" @@ -22,69 +22,70 @@ import ( type Node struct { FormatVersion string `yaml:"format"` - //Enable config + runtime debug of node via config o/ + // Enable config + runtime debug of node via config o/ Debug bool `yaml:"debug,omitempty"` - //If enabled, the node (and its child) will report their own statistics + // If enabled, the node (and its child) will report their own statistics Profiling bool `yaml:"profiling,omitempty"` - //Name, author, description and reference(s) for parser pattern + // Name, author, description and reference(s) for parser pattern Name string `yaml:"name,omitempty"` Author string `yaml:"author,omitempty"` Description string `yaml:"description,omitempty"` References []string `yaml:"references,omitempty"` - //if debug is present in the node, keep its specific Logger in runtime structure + // if debug is present in the node, keep its specific Logger in runtime structure Logger *log.Entry `yaml:"-"` - //This is mostly a hack to make writing less repetitive. - //relying on stage, we know which field to parse, and we - //can also promote log to next stage on success + // This is mostly a hack to make writing less repetitive. + // relying on stage, we know which field to parse, and we + // can also promote log to next stage on success Stage string `yaml:"stage,omitempty"` - //OnSuccess allows to tag a node to be able to move log to next stage on success + // OnSuccess allows to tag a node to be able to move log to next stage on success OnSuccess string `yaml:"onsuccess,omitempty"` - rn string //this is only for us in debug, a random generated name for each node - //Filter is executed at runtime (with current log line as context) - //and must succeed or node is exited + rn string // this is only for us in debug, a random generated name for each node + // Filter is executed at runtime (with current log line as context) + // and must succeed or node is exited Filter string `yaml:"filter,omitempty"` - RunTimeFilter *vm.Program `yaml:"-" json:"-"` //the actual compiled filter - //If node has leafs, execute all of them until one asks for a 'break' + RunTimeFilter *vm.Program `yaml:"-" json:"-"` // the actual compiled filter + // If node has leafs, execute all of them until one asks for a 'break' LeavesNodes []Node `yaml:"nodes,omitempty"` - //Flag used to describe when to 'break' or return an 'error' + // Flag used to describe when to 'break' or return an 'error' EnrichFunctions EnricherCtx /* If the node is actually a leaf, it can have : grok, enrich, statics */ - //pattern_syntax are named grok patterns that are re-utilized over several grok patterns + // pattern_syntax are named grok patterns that are re-utilized over several grok patterns SubGroks yaml.MapSlice `yaml:"pattern_syntax,omitempty"` - //Holds a grok pattern + // Holds a grok pattern Grok GrokPattern `yaml:"grok,omitempty"` - //Statics can be present in any type of node and is executed last + // Statics can be present in any type of node and is executed last Statics []ExtraField `yaml:"statics,omitempty"` - //Stash allows to capture data from the log line and store it in an accessible cache + // Stash allows to capture data from the log line and store it in an accessible cache Stash []DataCapture `yaml:"stash,omitempty"` - //Whitelists + // Whitelists Whitelist Whitelist `yaml:"whitelist,omitempty"` Data []*types.DataSource `yaml:"data,omitempty"` } -func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { - - //stage is being set automagically +func (n *Node) validate(ectx EnricherCtx) error { + // stage is being set automagically if n.Stage == "" { - return fmt.Errorf("stage needs to be an existing stage") + return errors.New("stage needs to be an existing stage") } /* "" behaves like continue */ if n.OnSuccess != "continue" && n.OnSuccess != "next_stage" && n.OnSuccess != "" { return fmt.Errorf("onsuccess '%s' not continue,next_stage", n.OnSuccess) } + if n.Filter != "" && n.RunTimeFilter == nil { return fmt.Errorf("non-empty filter '%s' was not compiled", n.Filter) } if n.Grok.RunTimeRegexp != nil || n.Grok.TargetField != "" { if n.Grok.TargetField == "" && n.Grok.ExpValue == "" { - return fmt.Errorf("grok requires 'expression' or 'apply_on'") + return errors.New("grok requires 'expression' or 'apply_on'") } + if n.Grok.RegexpName == "" && n.Grok.RegexpValue == "" { - return fmt.Errorf("grok needs 'pattern' or 'name'") + return errors.New("grok needs 'pattern' or 'name'") } } @@ -93,6 +94,7 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if static.ExpValue == "" { return fmt.Errorf("static %d : when method is set, expression must be present", idx) } + if _, ok := ectx.Registered[static.Method]; !ok { log.Warningf("the method '%s' doesn't exist or the plugin has not been initialized", static.Method) } @@ -100,6 +102,7 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if static.Meta == "" && static.Parsed == "" && static.TargetByName == "" { return fmt.Errorf("static %d : at least one of meta/event/target must be set", idx) } + if static.Value == "" && static.RunTimeValue == nil { return fmt.Errorf("static %d value or expression must be set", idx) } @@ -110,72 +113,76 @@ func (n *Node) validate(pctx *UnixParserCtx, ectx EnricherCtx) error { if stash.Name == "" { return fmt.Errorf("stash %d : name must be set", idx) } + if stash.Value == "" { return fmt.Errorf("stash %s : value expression must be set", stash.Name) } + if stash.Key == "" { return fmt.Errorf("stash %s : key expression must be set", stash.Name) } + if stash.TTL == "" { return fmt.Errorf("stash %s : ttl must be set", stash.Name) } + if stash.Strategy == "" { stash.Strategy = "LRU" } - //should be configurable + // should be configurable if stash.MaxMapSize == 0 { stash.MaxMapSize = 100 } } + return nil } -func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[string]interface{}) (bool, error) { - var NodeState bool - var NodeHasOKGrok bool +func (n *Node) processFilter(cachedExprEnv map[string]interface{}) (bool, error) { clog := n.Logger + if n.RunTimeFilter == nil { + clog.Tracef("Node has not filter, enter") + return true, nil + } - cachedExprEnv := expressionEnv + // Evaluate node's filter + output, err := exprhelpers.Run(n.RunTimeFilter, cachedExprEnv, clog, n.Debug) + if err != nil { + clog.Warningf("failed to run filter : %v", err) + clog.Debugf("Event leaving node : ko") - clog.Tracef("Event entering node") - if n.RunTimeFilter != nil { - //Evaluate node's filter - output, err := exprhelpers.Run(n.RunTimeFilter, cachedExprEnv, clog, n.Debug) - if err != nil { - clog.Warningf("failed to run filter : %v", err) - clog.Debugf("Event leaving node : ko") - return false, nil - } + return false, nil + } - switch out := output.(type) { - case bool: - if !out { - clog.Debugf("Event leaving node : ko (failed filter)") - return false, nil - } - default: - clog.Warningf("Expr '%s' returned non-bool, abort : %T", n.Filter, output) - clog.Debugf("Event leaving node : ko") + switch out := output.(type) { + case bool: + if !out { + clog.Debugf("Event leaving node : ko (failed filter)") return false, nil } - NodeState = true - } else { - clog.Tracef("Node has not filter, enter") - NodeState = true - } + default: + clog.Warningf("Expr '%s' returned non-bool, abort : %T", n.Filter, output) + clog.Debugf("Event leaving node : ko") - if n.Name != "" { - NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() + return false, nil } - exprErr := error(nil) - isWhitelisted := n.CheckIPsWL(p.ParseIPSources()) + + return true, nil +} + +func (n *Node) processWhitelist(cachedExprEnv map[string]interface{}, p *types.Event) (bool, error) { + var exprErr error + + isWhitelisted := n.CheckIPsWL(p) if !isWhitelisted { - isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv) + isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv, p) } + if exprErr != nil { // Previous code returned nil if there was an error, so we keep this behavior return false, nil //nolint:nilerr } + if isWhitelisted && !p.Whitelisted { p.Whitelisted = true p.WhitelistReason = n.Whitelist.Reason @@ -185,95 +192,145 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri for k := range p.Overflow.Sources { ips = append(ips, k) } - clog.Infof("Ban for %s whitelisted, reason [%s]", strings.Join(ips, ","), n.Whitelist.Reason) + + n.Logger.Infof("Ban for %s whitelisted, reason [%s]", strings.Join(ips, ","), n.Whitelist.Reason) + p.Overflow.Whitelisted = true } } - //Process grok if present, should be exclusive with nodes :) + return isWhitelisted, nil +} + +func (n *Node) processGrok(p *types.Event, cachedExprEnv map[string]any) (bool, bool, error) { + // Process grok if present, should be exclusive with nodes :) + clog := n.Logger + var NodeHasOKGrok bool gstr := "" - if n.Grok.RunTimeRegexp != nil { - clog.Tracef("Processing grok pattern : %s : %p", n.Grok.RegexpName, n.Grok.RunTimeRegexp) - //for unparsed, parsed etc. set sensible defaults to reduce user hassle - if n.Grok.TargetField != "" { - //it's a hack to avoid using real reflect - if n.Grok.TargetField == "Line.Raw" { - gstr = p.Line.Raw - } else if val, ok := p.Parsed[n.Grok.TargetField]; ok { - gstr = val - } else { - clog.Debugf("(%s) target field '%s' doesn't exist in %v", n.rn, n.Grok.TargetField, p.Parsed) - NodeState = false - } - } else if n.Grok.RunTimeValue != nil { - output, err := exprhelpers.Run(n.Grok.RunTimeValue, cachedExprEnv, clog, n.Debug) - if err != nil { - clog.Warningf("failed to run RunTimeValue : %v", err) - NodeState = false - } - switch out := output.(type) { - case string: - gstr = out - case int: - gstr = fmt.Sprintf("%d", out) - case float64, float32: - gstr = fmt.Sprintf("%f", out) - default: - clog.Errorf("unexpected return type for RunTimeValue : %T", output) - } - } - var groklabel string - if n.Grok.RegexpName == "" { - groklabel = fmt.Sprintf("%5.5s...", n.Grok.RegexpValue) - } else { - groklabel = n.Grok.RegexpName - } - grok := n.Grok.RunTimeRegexp.Parse(gstr) - if len(grok) > 0 { - /*tag explicitly that the *current* node had a successful grok pattern. it's important to know success state*/ - NodeHasOKGrok = true - clog.Debugf("+ Grok '%s' returned %d entries to merge in Parsed", groklabel, len(grok)) - //We managed to grok stuff, merged into parse - for k, v := range grok { - clog.Debugf("\t.Parsed['%s'] = '%s'", k, v) - p.Parsed[k] = v - } - // if the grok succeed, process associated statics - err := n.ProcessStatics(n.Grok.Statics, p) - if err != nil { - clog.Errorf("(%s) Failed to process statics : %v", n.rn, err) - return false, err - } + if n.Grok.RunTimeRegexp == nil { + clog.Tracef("! No grok pattern : %p", n.Grok.RunTimeRegexp) + return true, false, nil + } + + clog.Tracef("Processing grok pattern : %s : %p", n.Grok.RegexpName, n.Grok.RunTimeRegexp) + // for unparsed, parsed etc. set sensible defaults to reduce user hassle + if n.Grok.TargetField != "" { + // it's a hack to avoid using real reflect + if n.Grok.TargetField == "Line.Raw" { + gstr = p.Line.Raw + } else if val, ok := p.Parsed[n.Grok.TargetField]; ok { + gstr = val } else { - //grok failed, node failed - clog.Debugf("+ Grok '%s' didn't return data on '%s'", groklabel, gstr) - NodeState = false + clog.Debugf("(%s) target field '%s' doesn't exist in %v", n.rn, n.Grok.TargetField, p.Parsed) + return false, false, nil + } + } else if n.Grok.RunTimeValue != nil { + output, err := exprhelpers.Run(n.Grok.RunTimeValue, cachedExprEnv, clog, n.Debug) + if err != nil { + clog.Warningf("failed to run RunTimeValue : %v", err) + return false, false, nil + } + + switch out := output.(type) { + case string: + gstr = out + case int: + gstr = fmt.Sprintf("%d", out) + case float64, float32: + gstr = fmt.Sprintf("%f", out) + default: + clog.Errorf("unexpected return type for RunTimeValue : %T", output) } + } + var groklabel string + if n.Grok.RegexpName == "" { + groklabel = fmt.Sprintf("%5.5s...", n.Grok.RegexpValue) } else { - clog.Tracef("! No grok pattern : %p", n.Grok.RunTimeRegexp) + groklabel = n.Grok.RegexpName + } + + grok := n.Grok.RunTimeRegexp.Parse(gstr) + + if len(grok) == 0 { + // grok failed, node failed + clog.Debugf("+ Grok '%s' didn't return data on '%s'", groklabel, gstr) + return false, false, nil + } + + /*tag explicitly that the *current* node had a successful grok pattern. it's important to know success state*/ + NodeHasOKGrok = true + + clog.Debugf("+ Grok '%s' returned %d entries to merge in Parsed", groklabel, len(grok)) + // We managed to grok stuff, merged into parse + for k, v := range grok { + clog.Debugf("\t.Parsed['%s'] = '%s'", k, v) + p.Parsed[k] = v + } + // if the grok succeed, process associated statics + err := n.ProcessStatics(n.Grok.Statics, p) + if err != nil { + clog.Errorf("(%s) Failed to process statics : %v", n.rn, err) + return false, false, err + } + + return true, NodeHasOKGrok, nil +} + +func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[string]interface{}) (bool, error) { + clog := n.Logger + + cachedExprEnv := expressionEnv + + clog.Tracef("Event entering node") + + NodeState, err := n.processFilter(cachedExprEnv) + if err != nil { + return false, err } - //Process the stash (data collection) if : a grok was present and succeeded, or if there is no grok + if !NodeState { + return false, nil + } + + if n.Name != "" { + NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() + } + + isWhitelisted, err := n.processWhitelist(cachedExprEnv, p) + if err != nil { + return false, err + } + + NodeState, NodeHasOKGrok, err := n.processGrok(p, cachedExprEnv) + if err != nil { + return false, err + } + + // Process the stash (data collection) if : a grok was present and succeeded, or if there is no grok if NodeHasOKGrok || n.Grok.RunTimeRegexp == nil { for idx, stash := range n.Stash { - var value string - var key string + var ( + key string + value string + ) + if stash.ValueExpression == nil { clog.Warningf("Stash %d has no value expression, skipping", idx) continue } + if stash.KeyExpression == nil { clog.Warningf("Stash %d has no key expression, skipping", idx) continue } - //collect the data + // collect the data output, err := exprhelpers.Run(stash.ValueExpression, cachedExprEnv, clog, n.Debug) if err != nil { clog.Warningf("Error while running stash val expression : %v", err) } - //can we expect anything else than a string ? + // can we expect anything else than a string ? switch output := output.(type) { case string: value = output @@ -282,12 +339,12 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri continue } - //collect the key + // collect the key output, err = exprhelpers.Run(stash.KeyExpression, cachedExprEnv, clog, n.Debug) if err != nil { clog.Warningf("Error while running stash key expression : %v", err) } - //can we expect anything else than a string ? + // can we expect anything else than a string ? switch output := output.(type) { case string: key = output @@ -299,15 +356,18 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } } - //Iterate on leafs + // Iterate on leafs for _, leaf := range n.LeavesNodes { ret, err := leaf.process(p, ctx, cachedExprEnv) if err != nil { clog.Tracef("\tNode (%s) failed : %v", leaf.rn, err) clog.Debugf("Event leaving node : ko") + return false, err } + clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaf.rn, ret, n.OnSuccess) + if ret { NodeState = true /* if child is successful, stop processing */ @@ -328,12 +388,14 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri clog.Tracef("State after nodes : %v", NodeState) - //grok or leafs failed, don't process statics + // grok or leafs failed, don't process statics if !NodeState { if n.Name != "" { NodesHitsKo.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() } + clog.Debugf("Event leaving node : ko") + return NodeState, nil } @@ -360,9 +422,10 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri if NodeState { clog.Debugf("Event leaving node : ok") log.Tracef("node is successful, check strategy") + if n.OnSuccess == "next_stage" { idx := stageidx(p.Stage, ctx.Stages) - //we're at the last stage + // we're at the last stage if idx+1 == len(ctx.Stages) { clog.Debugf("node reached the last stage : %s", p.Stage) } else { @@ -375,15 +438,16 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } else { clog.Debugf("Event leaving node : ko") } + clog.Tracef("Node successful, continue") + return NodeState, nil } func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { var err error - var valid bool - valid = false + valid := false dumpr := spew.ConfigState{MaxDepth: 1, DisablePointerAddresses: true} n.rn = seed.Generate() @@ -393,20 +457,17 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* if the node has debugging enabled, create a specific logger with debug that will be used only for processing this node ;) */ if n.Debug { - var clog = log.New() + clog := log.New() if err = types.ConfigureLogger(clog); err != nil { - log.Fatalf("While creating bucket-specific logger : %s", err) + return fmt.Errorf("while creating bucket-specific logger: %w", err) } + clog.SetLevel(log.DebugLevel) - n.Logger = clog.WithFields(log.Fields{ - "id": n.rn, - }) + n.Logger = clog.WithField("id", n.rn) n.Logger.Infof("%s has debug enabled", n.Name) } else { /* else bind it to the default one (might find something more elegant here)*/ - n.Logger = log.WithFields(log.Fields{ - "id": n.rn, - }) + n.Logger = log.WithField("id", n.rn) } /* display info about top-level nodes, they should be the only one with explicit stage name ?*/ @@ -414,7 +475,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { n.Logger.Tracef("Compiling : %s", dumpr.Sdump(n)) - //compile filter if present + // compile filter if present if n.Filter != "" { n.RunTimeFilter, err = expr.Compile(n.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { @@ -425,12 +486,15 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* handle pattern_syntax and groks */ for _, pattern := range n.SubGroks { n.Logger.Tracef("Adding subpattern '%s' : '%s'", pattern.Key, pattern.Value) + if err = pctx.Grok.Add(pattern.Key.(string), pattern.Value.(string)); err != nil { if errors.Is(err, grokky.ErrAlreadyExist) { n.Logger.Warningf("grok '%s' already registred", pattern.Key) continue } + n.Logger.Errorf("Unable to compile subpattern %s : %v", pattern.Key, err) + return err } } @@ -438,28 +502,36 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* load grok by name or compile in-place */ if n.Grok.RegexpName != "" { n.Logger.Tracef("+ Regexp Compilation '%s'", n.Grok.RegexpName) + n.Grok.RunTimeRegexp, err = pctx.Grok.Get(n.Grok.RegexpName) if err != nil { - return fmt.Errorf("unable to find grok '%s' : %v", n.Grok.RegexpName, err) + return fmt.Errorf("unable to find grok '%s': %v", n.Grok.RegexpName, err) } + if n.Grok.RunTimeRegexp == nil { return fmt.Errorf("empty grok '%s'", n.Grok.RegexpName) } + n.Logger.Tracef("%s regexp: %s", n.Grok.RegexpName, n.Grok.RunTimeRegexp.String()) + valid = true } else if n.Grok.RegexpValue != "" { if strings.HasSuffix(n.Grok.RegexpValue, "\n") { n.Logger.Debugf("Beware, pattern ends with \\n : '%s'", n.Grok.RegexpValue) } + n.Grok.RunTimeRegexp, err = pctx.Grok.Compile(n.Grok.RegexpValue) if err != nil { return fmt.Errorf("failed to compile grok '%s': %v", n.Grok.RegexpValue, err) } + if n.Grok.RunTimeRegexp == nil { // We shouldn't be here because compilation succeeded, so regexp shouldn't be nil return fmt.Errorf("grok compilation failure: %s", n.Grok.RegexpValue) } + n.Logger.Tracef("%s regexp : %s", n.Grok.RegexpValue, n.Grok.RunTimeRegexp.String()) + valid = true } @@ -473,7 +545,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } /* load grok statics */ - //compile expr statics if present + // compile expr statics if present for idx := range n.Grok.Statics { if n.Grok.Statics[idx].ExpValue != "" { n.Grok.Statics[idx].RunTimeValue, err = expr.Compile(n.Grok.Statics[idx].ExpValue, @@ -482,6 +554,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { return err } } + valid = true } @@ -505,7 +578,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } logLvl := n.Logger.Logger.GetLevel() - //init the cache, does it make sense to create it here just to be sure everything is fine ? + // init the cache, does it make sense to create it here just to be sure everything is fine ? if err = cache.CacheInit(cache.CacheCfg{ Size: n.Stash[i].MaxMapSize, TTL: n.Stash[i].TTLVal, @@ -526,14 +599,18 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { if !n.LeavesNodes[idx].Debug && n.Debug { n.LeavesNodes[idx].Debug = true } + if !n.LeavesNodes[idx].Profiling && n.Profiling { n.LeavesNodes[idx].Profiling = true } + n.LeavesNodes[idx].Stage = n.Stage + err = n.LeavesNodes[idx].compile(pctx, ectx) if err != nil { return err } + valid = true } @@ -546,6 +623,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { return err } } + valid = true } @@ -554,18 +632,16 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { if err != nil { return err } + valid = valid || whitelistValid if !valid { /* node is empty, error force return */ n.Logger.Error("Node is empty or invalid, abort") n.Stage = "" - return fmt.Errorf("Node is empty") - } - if err := n.validate(pctx, ectx); err != nil { - return err + return errors.New("Node is empty") } - return nil + return n.validate(ectx) } diff --git a/pkg/parser/node_test.go b/pkg/parser/node_test.go index d85aa82a8ae..76d35a9ffb0 100644 --- a/pkg/parser/node_test.go +++ b/pkg/parser/node_test.go @@ -49,18 +49,18 @@ func TestParserConfigs(t *testing.T) { } for idx := range CfgTests { err := CfgTests[idx].NodeCfg.compile(pctx, EnricherCtx{}) - if CfgTests[idx].Compiles == true && err != nil { + if CfgTests[idx].Compiles && err != nil { t.Fatalf("Compile: (%d/%d) expected valid, got : %s", idx+1, len(CfgTests), err) } - if CfgTests[idx].Compiles == false && err == nil { + if !CfgTests[idx].Compiles && err == nil { t.Fatalf("Compile: (%d/%d) expected error", idx+1, len(CfgTests)) } - err = CfgTests[idx].NodeCfg.validate(pctx, EnricherCtx{}) - if CfgTests[idx].Valid == true && err != nil { + err = CfgTests[idx].NodeCfg.validate(EnricherCtx{}) + if CfgTests[idx].Valid && err != nil { t.Fatalf("Valid: (%d/%d) expected valid, got : %s", idx+1, len(CfgTests), err) } - if CfgTests[idx].Valid == false && err == nil { + if !CfgTests[idx].Valid && err == nil { t.Fatalf("Valid: (%d/%d) expected error", idx+1, len(CfgTests)) } } diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index 04d08cc2785..269d51a1ba2 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -24,17 +24,21 @@ type TestFile struct { Results []types.Event `yaml:"results,omitempty"` } -var debug bool = false +var debug = false func TestParser(t *testing.T) { debug = true + log.SetLevel(log.InfoLevel) - var envSetting = os.Getenv("TEST_ONLY") + + envSetting := os.Getenv("TEST_ONLY") + pctx, ectx, err := prepTests() if err != nil { t.Fatalf("failed to load env : %s", err) } - //Init the enricher + + // Init the enricher if envSetting != "" { if err := testOneParser(pctx, ectx, envSetting, nil); err != nil { t.Fatalf("Test '%s' failed : %s", envSetting, err) @@ -44,12 +48,15 @@ func TestParser(t *testing.T) { if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { if !fd.IsDir() { continue } + fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) + if err := testOneParser(pctx, ectx, fname, nil); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } @@ -59,13 +66,17 @@ func TestParser(t *testing.T) { func BenchmarkParser(t *testing.B) { log.Printf("start bench !!!!") + debug = false + log.SetLevel(log.ErrorLevel) + pctx, ectx, err := prepTests() if err != nil { t.Fatalf("failed to load env : %s", err) } - var envSetting = os.Getenv("TEST_ONLY") + + envSetting := os.Getenv("TEST_ONLY") if envSetting != "" { if err := testOneParser(pctx, ectx, envSetting, t); err != nil { @@ -76,12 +87,15 @@ func BenchmarkParser(t *testing.B) { if err != nil { t.Fatalf("Unable to read test directory : %s", err) } + for _, fd := range fds { if !fd.IsDir() { continue } + fname := "./tests/" + fd.Name() log.Infof("Running test on %s", fname) + if err := testOneParser(pctx, ectx, fname, t); err != nil { t.Fatalf("Test '%s' failed : %s", fname, err) } @@ -91,49 +105,58 @@ func BenchmarkParser(t *testing.B) { func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing.B) error { var ( - err error - pnodes []Node - + err error + pnodes []Node parser_configs []Stagefile ) + log.Warningf("testing %s", dir) + parser_cfg_file := fmt.Sprintf("%s/parsers.yaml", dir) + cfg, err := os.ReadFile(parser_cfg_file) if err != nil { - return fmt.Errorf("failed opening %s : %s", parser_cfg_file, err) + return fmt.Errorf("failed opening %s: %w", parser_cfg_file, err) } + tmpl, err := template.New("test").Parse(string(cfg)) if err != nil { - return fmt.Errorf("failed to parse template %s : %s", cfg, err) + return fmt.Errorf("failed to parse template %s: %w", cfg, err) } + var out bytes.Buffer + err = tmpl.Execute(&out, map[string]string{"TestDirectory": dir}) if err != nil { panic(err) } + if err = yaml.UnmarshalStrict(out.Bytes(), &parser_configs); err != nil { - return fmt.Errorf("failed unmarshaling %s : %s", parser_cfg_file, err) + return fmt.Errorf("failed to parse %s: %w", parser_cfg_file, err) } pnodes, err = LoadStages(parser_configs, pctx, ectx) if err != nil { - return fmt.Errorf("unable to load parser config : %s", err) + return fmt.Errorf("unable to load parser config: %w", err) } - //TBD: Load post overflows - //func testFile(t *testing.T, file string, pctx UnixParserCtx, nodes []Node) bool { + // TBD: Load post overflows + // func testFile(t *testing.T, file string, pctx UnixParserCtx, nodes []Node) bool { parser_test_file := fmt.Sprintf("%s/test.yaml", dir) tests := loadTestFile(parser_test_file) count := 1 + if b != nil { count = b.N b.ResetTimer() } - for n := 0; n < count; n++ { - if testFile(tests, *pctx, pnodes) != true { - return fmt.Errorf("test failed !") + + for range(count) { + if !testFile(tests, *pctx, pnodes) { + return errors.New("test failed") } } + return nil } @@ -147,26 +170,34 @@ func prepTests() (*UnixParserCtx, EnricherCtx, error) { err = exprhelpers.Init(nil) if err != nil { - log.Fatalf("exprhelpers init failed: %s", err) + return nil, ectx, fmt.Errorf("exprhelpers init failed: %w", err) } - //Load enrichment + // Load enrichment datadir := "./test_data/" - ectx, err = Loadplugin(datadir) + + err = exprhelpers.GeoIPInit(datadir) if err != nil { - log.Fatalf("failed to load plugin geoip : %v", err) + log.Fatalf("unable to initialize GeoIP: %s", err) } + + ectx, err = Loadplugin() + if err != nil { + return nil, ectx, fmt.Errorf("failed to load plugin geoip: %v", err) + } + log.Printf("Loaded -> %+v", ectx) - //Load the parser patterns + // Load the parser patterns cfgdir := "../../config/" /* this should be refactored to 2 lines :p */ // Init the parser pctx, err = Init(map[string]interface{}{"patterns": cfgdir + string("/patterns/"), "data": "./tests/"}) if err != nil { - return nil, ectx, fmt.Errorf("failed to initialize parser : %v", err) + return nil, ectx, fmt.Errorf("failed to initialize parser: %v", err) } + return pctx, ectx, nil } @@ -175,43 +206,54 @@ func loadTestFile(file string) []TestFile { if err != nil { log.Fatalf("yamlFile.Get err #%v ", err) } + dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) + var testSet []TestFile + for { tf := TestFile{} + err := dec.Decode(&tf) if err != nil { if errors.Is(err, io.EOF) { break } + log.Fatalf("Failed to load testfile '%s' yaml error : %v", file, err) + return nil } + testSet = append(testSet, tf) } + return testSet } func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bool) { var retInfo []string - var valid = false + + valid := false expectMaps := []map[string]string{expected.Parsed, expected.Meta, expected.Enriched} outMaps := []map[string]string{out.Parsed, out.Meta, out.Enriched} outLabels := []string{"Parsed", "Meta", "Enriched"} - //allow to check as well for stage and processed flags + // allow to check as well for stage and processed flags if expected.Stage != "" { if expected.Stage != out.Stage { if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch stage %s != %s", expected.Stage, out.Stage)) } + goto checkFinished - } else { - valid = true - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok stage %s == %s", expected.Stage, out.Stage)) - } + } + + valid = true + + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok stage %s == %s", expected.Stage, out.Stage)) } } @@ -219,48 +261,58 @@ func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bo if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch process %t != %t", expected.Process, out.Process)) } + goto checkFinished - } else { - valid = true - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok process %t == %t", expected.Process, out.Process)) - } + } + + valid = true + + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok process %t == %t", expected.Process, out.Process)) } if expected.Whitelisted != out.Whitelisted { if debug { retInfo = append(retInfo, fmt.Sprintf("mismatch whitelist %t != %t", expected.Whitelisted, out.Whitelisted)) } + goto checkFinished - } else { - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok whitelist %t == %t", expected.Whitelisted, out.Whitelisted)) - } - valid = true } - for mapIdx := 0; mapIdx < len(expectMaps); mapIdx++ { + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok whitelist %t == %t", expected.Whitelisted, out.Whitelisted)) + } + + valid = true + + for mapIdx := range(len(expectMaps)) { for expKey, expVal := range expectMaps[mapIdx] { - if outVal, ok := outMaps[mapIdx][expKey]; ok { - if outVal == expVal { //ok entry - if debug { - retInfo = append(retInfo, fmt.Sprintf("ok %s[%s] %s == %s", outLabels[mapIdx], expKey, expVal, outVal)) - } - valid = true - } else { //mismatch entry - if debug { - retInfo = append(retInfo, fmt.Sprintf("mismatch %s[%s] %s != %s", outLabels[mapIdx], expKey, expVal, outVal)) - } - valid = false - goto checkFinished - } - } else { //missing entry + outVal, ok := outMaps[mapIdx][expKey] + if !ok { if debug { retInfo = append(retInfo, fmt.Sprintf("missing entry %s[%s]", outLabels[mapIdx], expKey)) } + valid = false + goto checkFinished } + + if outVal != expVal { // ok entry + if debug { + retInfo = append(retInfo, fmt.Sprintf("mismatch %s[%s] %s != %s", outLabels[mapIdx], expKey, expVal, outVal)) + } + + valid = false + + goto checkFinished + } + + if debug { + retInfo = append(retInfo, fmt.Sprintf("ok %s[%s] %s == %s", outLabels[mapIdx], expKey, expVal, outVal)) + } + + valid = true } } checkFinished: @@ -273,6 +325,7 @@ checkFinished: retInfo = append(retInfo, fmt.Sprintf("KO ! \n\t%s", strings.Join(retInfo, "\n\t"))) } } + return retInfo, valid } @@ -284,9 +337,10 @@ func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error if err != nil { log.Errorf("Failed to process %s : %v", spew.Sdump(in), err) } - //log.Infof("Parser output : %s", spew.Sdump(out)) + // log.Infof("Parser output : %s", spew.Sdump(out)) results = append(results, out) } + log.Infof("parsed %d lines", len(testSet.Lines)) log.Infof("got %d results", len(results)) @@ -295,21 +349,22 @@ func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error only the keys of the expected part are checked against result */ if len(testSet.Results) == 0 && len(results) == 0 { - log.Fatal("No results, no tests, abort.") - return false, fmt.Errorf("no tests, no results") + return false, errors.New("no tests, no results") } reCheck: failinfo := []string{} + for ridx, result := range results { for eidx, expected := range testSet.Results { explain, match := matchEvent(expected, result, debug) - if match == true { + if match { log.Infof("expected %d/%d matches result %d/%d", eidx, len(testSet.Results), ridx, len(results)) + if len(explain) > 0 { log.Printf("-> %s", explain[len(explain)-1]) } - //don't do this at home : delete current element from list and redo + // don't do this at home : delete current element from list and redo results[len(results)-1], results[ridx] = results[ridx], results[len(results)-1] results = results[:len(results)-1] @@ -317,34 +372,40 @@ reCheck: testSet.Results = testSet.Results[:len(testSet.Results)-1] goto reCheck - } else { - failinfo = append(failinfo, explain...) } + + failinfo = append(failinfo, explain...) } } + if len(results) > 0 { log.Printf("Errors : %s", strings.Join(failinfo, " / ")) return false, fmt.Errorf("leftover results : %+v", results) } + if len(testSet.Results) > 0 { log.Printf("Errors : %s", strings.Join(failinfo, " / ")) return false, fmt.Errorf("leftover expected results : %+v", testSet.Results) } + return true, nil } func testFile(testSet []TestFile, pctx UnixParserCtx, nodes []Node) bool { log.Warning("Going to process one test set") + for _, tf := range testSet { - //func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error) { + // func testSubSet(testSet TestFile, pctx UnixParserCtx, nodes []Node) (bool, error) { testOk, err := testSubSet(tf, pctx, nodes) if err != nil { log.Fatalf("test failed : %s", err) } + if !testOk { log.Fatalf("failed test : %+v", tf) } } + return true } @@ -369,48 +430,61 @@ func TestGeneratePatternsDoc(t *testing.T) { if err != nil { t.Fatalf("unable to load patterns : %s", err) } + log.Infof("-> %s", spew.Sdump(pctx)) /*don't judge me, we do it for the users*/ p := make(PairList, len(pctx.Grok.Patterns)) i := 0 + for key, val := range pctx.Grok.Patterns { p[i] = Pair{key, val} p[i].Value = strings.ReplaceAll(p[i].Value, "{%{", "\\{\\%\\{") i++ } + sort.Sort(p) - f, err := os.OpenFile("./patterns-documentation.md", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile("./patterns-documentation.md", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { t.Fatalf("failed to open : %s", err) } + if _, err := f.WriteString("# Patterns documentation\n\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("You will find here a generated documentation of all the patterns loaded by crowdsec.\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("They are sorted by pattern length, and are meant to be used in parsers, in the form %{PATTERN_NAME}.\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("\n\n"); err != nil { t.Fatal("failed to write to file") } + for _, k := range p { if _, err := fmt.Fprintf(f, "## %s\n\nPattern :\n```\n%s\n```\n\n", k.Key, k.Value); err != nil { t.Fatal("failed to write to file") } + fmt.Printf("%v\t%v\n", k.Key, k.Value) } + if _, err := f.WriteString("\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("# Documentation generation\n"); err != nil { t.Fatal("failed to write to file") } + if _, err := f.WriteString("This documentation is generated by `pkg/parser` : `GO_WANT_TEST_DOC=1 go test -run TestGeneratePatternsDoc`\n"); err != nil { t.Fatal("failed to write to file") } + f.Close() } diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index 4f4f6a0f3d0..8068690b68f 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -42,8 +42,8 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { iter := reflect.ValueOf(evt).Elem() if (iter == reflect.Value{}) || iter.IsZero() { - log.Tracef("event is nill") - //event is nill + log.Tracef("event is nil") + //event is nil return false } for _, f := range strings.Split(target, ".") { @@ -155,7 +155,7 @@ func (n *Node) ProcessStatics(statics []ExtraField, event *types.Event) error { /*still way too hackish, but : inject all the results in enriched, and */ if enricherPlugin, ok := n.EnrichFunctions.Registered[static.Method]; ok { clog.Tracef("Found method '%s'", static.Method) - ret, err := enricherPlugin.EnrichFunc(value, event, enricherPlugin.Ctx, n.Logger.WithField("method", static.Method)) + ret, err := enricherPlugin.EnrichFunc(value, event, n.Logger.WithField("method", static.Method)) if err != nil { clog.Errorf("method '%s' returned an error : %v", static.Method, err) } @@ -221,6 +221,24 @@ var NodesHitsKo = prometheus.NewCounterVec( []string{"source", "type", "name"}, ) +// + +var NodesWlHitsOk = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_ok_total", + Help: "Total events successfully whitelisted by node.", + }, + []string{"source", "type", "name", "reason"}, +) + +var NodesWlHits = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_node_wl_hits_total", + Help: "Total events processed by whitelist node.", + }, + []string{"source", "type", "name", "reason"}, +) + func stageidx(stage string, stages []string) int { for i, v := range stages { if stage == v { diff --git a/pkg/parser/stage.go b/pkg/parser/stage.go index 1eac2b83ede..b98db350254 100644 --- a/pkg/parser/stage.go +++ b/pkg/parser/stage.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + // enable profiling _ "net/http/pprof" "os" "sort" @@ -20,7 +21,7 @@ import ( log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) @@ -84,12 +85,12 @@ func LoadStages(stageFiles []Stagefile, pctx *UnixParserCtx, ectx EnricherCtx) ( log.Tracef("no version in %s, assuming '1.0'", node.Name) node.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(node.FormatVersion, cwversion.Constraint_parser) + ok, err := constraint.Satisfies(node.FormatVersion, constraint.Parser) if err != nil { return nil, fmt.Errorf("failed to check version : %s", err) } if !ok { - log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, cwversion.Constraint_parser) + log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, constraint.Parser) continue } diff --git a/pkg/parser/unix_parser.go b/pkg/parser/unix_parser.go index 617e46189f3..351de8ade56 100644 --- a/pkg/parser/unix_parser.go +++ b/pkg/parser/unix_parser.go @@ -66,21 +66,20 @@ func NewParsers(hub *cwhub.Hub) *Parsers { } for _, itemType := range []string{cwhub.PARSERS, cwhub.POSTOVERFLOWS} { - for _, hubParserItem := range hub.GetItemMap(itemType) { - if hubParserItem.State.Installed { - stagefile := Stagefile{ - Filename: hubParserItem.State.LocalPath, - Stage: hubParserItem.Stage, - } - if itemType == cwhub.PARSERS { - parsers.StageFiles = append(parsers.StageFiles, stagefile) - } - if itemType == cwhub.POSTOVERFLOWS { - parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) - } + for _, hubParserItem := range hub.GetInstalledByType(itemType, false) { + stagefile := Stagefile{ + Filename: hubParserItem.State.LocalPath, + Stage: hubParserItem.Stage, + } + if itemType == cwhub.PARSERS { + parsers.StageFiles = append(parsers.StageFiles, stagefile) + } + if itemType == cwhub.POSTOVERFLOWS { + parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) } } } + if parsers.StageFiles != nil { sort.Slice(parsers.StageFiles, func(i, j int) bool { return parsers.StageFiles[i].Filename < parsers.StageFiles[j].Filename @@ -98,16 +97,20 @@ func NewParsers(hub *cwhub.Hub) *Parsers { func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { var err error - patternsDir := filepath.Join(cConfig.ConfigPaths.ConfigDir, "patterns/") + patternsDir := cConfig.ConfigPaths.PatternDir log.Infof("Loading grok library %s", patternsDir) /* load base regexps for two grok parsers */ - parsers.Ctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.ConfigPaths.DataDir}) + parsers.Ctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load parser patterns : %v", err) } - parsers.Povfwctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.ConfigPaths.DataDir}) + parsers.Povfwctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load postovflw parser patterns : %v", err) } @@ -117,7 +120,7 @@ func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { */ log.Infof("Loading enrich plugins") - parsers.EnricherCtx, err = Loadplugin(cConfig.ConfigPaths.DataDir) + parsers.EnricherCtx, err = Loadplugin() if err != nil { return parsers, fmt.Errorf("failed to load enrich plugin : %v", err) } diff --git a/pkg/parser/whitelist.go b/pkg/parser/whitelist.go index 027a9a2858a..e7b93a8d7da 100644 --- a/pkg/parser/whitelist.go +++ b/pkg/parser/whitelist.go @@ -4,8 +4,10 @@ import ( "fmt" "net" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/prometheus/client_golang/prometheus" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -36,11 +38,13 @@ func (n *Node) ContainsIPLists() bool { return len(n.Whitelist.B_Ips) > 0 || len(n.Whitelist.B_Cidrs) > 0 } -func (n *Node) CheckIPsWL(srcs []net.IP) bool { +func (n *Node) CheckIPsWL(p *types.Event) bool { + srcs := p.ParseIPSources() isWhitelisted := false if !n.ContainsIPLists() { return isWhitelisted } + NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() for _, src := range srcs { if isWhitelisted { break @@ -62,15 +66,19 @@ func (n *Node) CheckIPsWL(srcs []net.IP) bool { n.Logger.Tracef("whitelist: %s not in [%s]", src, v) } } + if isWhitelisted { + NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() + } return isWhitelisted } -func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) { +func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}, p *types.Event) (bool, error) { isWhitelisted := false if !n.ContainsExprLists() { return false, nil } + NodesWlHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() /* run whitelist expression tests anyway */ for eidx, e := range n.Whitelist.B_Exprs { //if we already know the event is whitelisted, skip the rest of the expressions @@ -94,6 +102,9 @@ func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) { n.Logger.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx]) } } + if isWhitelisted { + NodesWlHitsOk.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name, "reason": n.Whitelist.Reason}).Inc() + } return isWhitelisted, nil } diff --git a/pkg/parser/whitelist_test.go b/pkg/parser/whitelist_test.go index 8796aaedafe..02846f17fc1 100644 --- a/pkg/parser/whitelist_test.go +++ b/pkg/parser/whitelist_test.go @@ -62,7 +62,6 @@ func TestWhitelistCompile(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { node.Whitelist = tt.whitelist _, err := node.CompileWLs() @@ -284,14 +283,13 @@ func TestWhitelistCheck(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { var err error node.Whitelist = tt.whitelist node.CompileWLs() - isWhitelisted := node.CheckIPsWL(tt.event.ParseIPSources()) + isWhitelisted := node.CheckIPsWL(tt.event) if !isWhitelisted { - isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}) + isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}, tt.event) } require.NoError(t, err) require.Equal(t, tt.expected, isWhitelisted) diff --git a/pkg/protobufs/generate.go b/pkg/protobufs/generate.go new file mode 100644 index 00000000000..0e90d65b643 --- /dev/null +++ b/pkg/protobufs/generate.go @@ -0,0 +1,14 @@ +package protobufs + +// Dependencies: +// +// apt install protobuf-compiler +// +// keep this in sync with go.mod +// go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 +// +// Not the same versions as google.golang.org/grpc +// go list -m -versions google.golang.org/grpc/cmd/protoc-gen-go-grpc +// go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative notifier.proto diff --git a/pkg/protobufs/notifier.pb.go b/pkg/protobufs/notifier.pb.go index b5dc8113568..8c4754da773 100644 --- a/pkg/protobufs/notifier.pb.go +++ b/pkg/protobufs/notifier.pb.go @@ -1,16 +1,12 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v3.12.4 +// protoc-gen-go v1.34.2 +// protoc v3.21.12 // source: notifier.proto package protobufs import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -198,7 +194,7 @@ func file_notifier_proto_rawDescGZIP() []byte { } var file_notifier_proto_msgTypes = make([]protoimpl.MessageInfo, 3) -var file_notifier_proto_goTypes = []interface{}{ +var file_notifier_proto_goTypes = []any{ (*Notification)(nil), // 0: proto.Notification (*Config)(nil), // 1: proto.Config (*Empty)(nil), // 2: proto.Empty @@ -221,7 +217,7 @@ func file_notifier_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_notifier_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*Notification); i { case 0: return &v.state @@ -233,7 +229,7 @@ func file_notifier_proto_init() { return nil } } - file_notifier_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*Config); i { case 0: return &v.state @@ -245,7 +241,7 @@ func file_notifier_proto_init() { return nil } } - file_notifier_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*Empty); i { case 0: return &v.state @@ -277,119 +273,3 @@ func file_notifier_proto_init() { file_notifier_proto_goTypes = nil file_notifier_proto_depIdxs = nil } - -// Reference imports to suppress errors if they are not otherwise used. -var _ context.Context -var _ grpc.ClientConnInterface - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -const _ = grpc.SupportPackageIsVersion6 - -// NotifierClient is the client API for Notifier service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. -type NotifierClient interface { - Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) - Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) -} - -type notifierClient struct { - cc grpc.ClientConnInterface -} - -func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient { - return ¬ifierClient{cc} -} - -func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) { - out := new(Empty) - err := c.cc.Invoke(ctx, "/proto.Notifier/Notify", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) { - out := new(Empty) - err := c.cc.Invoke(ctx, "/proto.Notifier/Configure", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// NotifierServer is the server API for Notifier service. -type NotifierServer interface { - Notify(context.Context, *Notification) (*Empty, error) - Configure(context.Context, *Config) (*Empty, error) -} - -// UnimplementedNotifierServer can be embedded to have forward compatible implementations. -type UnimplementedNotifierServer struct { -} - -func (*UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented") -} -func (*UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented") -} - -func RegisterNotifierServer(s *grpc.Server, srv NotifierServer) { - s.RegisterService(&_Notifier_serviceDesc, srv) -} - -func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Notification) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Notify(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Notify", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Notify(ctx, req.(*Notification)) - } - return interceptor(ctx, in, info, handler) -} - -func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Config) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Configure(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Configure", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Configure(ctx, req.(*Config)) - } - return interceptor(ctx, in, info, handler) -} - -var _Notifier_serviceDesc = grpc.ServiceDesc{ - ServiceName: "proto.Notifier", - HandlerType: (*NotifierServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Notify", - Handler: _Notifier_Notify_Handler, - }, - { - MethodName: "Configure", - Handler: _Notifier_Configure_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "notifier.proto", -} diff --git a/pkg/protobufs/notifier_grpc.pb.go b/pkg/protobufs/notifier_grpc.pb.go new file mode 100644 index 00000000000..5141e83f98b --- /dev/null +++ b/pkg/protobufs/notifier_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v3.21.12 +// source: notifier.proto + +package protobufs + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Notifier_Notify_FullMethodName = "/proto.Notifier/Notify" + Notifier_Configure_FullMethodName = "/proto.Notifier/Configure" +) + +// NotifierClient is the client API for Notifier service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type NotifierClient interface { + Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) + Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) +} + +type notifierClient struct { + cc grpc.ClientConnInterface +} + +func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient { + return ¬ifierClient{cc} +} + +func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Empty) + err := c.cc.Invoke(ctx, Notifier_Notify_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Empty) + err := c.cc.Invoke(ctx, Notifier_Configure_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// NotifierServer is the server API for Notifier service. +// All implementations must embed UnimplementedNotifierServer +// for forward compatibility. +type NotifierServer interface { + Notify(context.Context, *Notification) (*Empty, error) + Configure(context.Context, *Config) (*Empty, error) + mustEmbedUnimplementedNotifierServer() +} + +// UnimplementedNotifierServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedNotifierServer struct{} + +func (UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented") +} +func (UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented") +} +func (UnimplementedNotifierServer) mustEmbedUnimplementedNotifierServer() {} +func (UnimplementedNotifierServer) testEmbeddedByValue() {} + +// UnsafeNotifierServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to NotifierServer will +// result in compilation errors. +type UnsafeNotifierServer interface { + mustEmbedUnimplementedNotifierServer() +} + +func RegisterNotifierServer(s grpc.ServiceRegistrar, srv NotifierServer) { + // If the following call pancis, it indicates UnimplementedNotifierServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Notifier_ServiceDesc, srv) +} + +func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Notification) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Notify(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Notify_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Notify(ctx, req.(*Notification)) + } + return interceptor(ctx, in, info, handler) +} + +func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Config) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Configure(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Configure_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Configure(ctx, req.(*Config)) + } + return interceptor(ctx, in, info, handler) +} + +// Notifier_ServiceDesc is the grpc.ServiceDesc for Notifier service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Notifier_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proto.Notifier", + HandlerType: (*NotifierServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Notify", + Handler: _Notifier_Notify_Handler, + }, + { + MethodName: "Configure", + Handler: _Notifier_Configure_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "notifier.proto", +} diff --git a/pkg/protobufs/plugin_interface.go b/pkg/protobufs/plugin_interface.go deleted file mode 100644 index fc89b2fa009..00000000000 --- a/pkg/protobufs/plugin_interface.go +++ /dev/null @@ -1,47 +0,0 @@ -package protobufs - -import ( - "context" - - plugin "github.com/hashicorp/go-plugin" - "google.golang.org/grpc" -) - -type Notifier interface { - Notify(ctx context.Context, notification *Notification) (*Empty, error) - Configure(ctx context.Context, config *Config) (*Empty, error) -} - -// This is the implementation of plugin.NotifierPlugin so we can serve/consume this. -type NotifierPlugin struct { - // GRPCPlugin must still implement the Plugin interface - plugin.Plugin - // Concrete implementation, written in Go. This is only used for plugins - // that are written in Go. - Impl Notifier -} - -type GRPCClient struct{ client NotifierClient } - -func (m *GRPCClient) Notify(ctx context.Context, notification *Notification) (*Empty, error) { - _, err := m.client.Notify(context.Background(), notification) - return &Empty{}, err -} - -func (m *GRPCClient) Configure(ctx context.Context, config *Config) (*Empty, error) { - _, err := m.client.Configure(context.Background(), config) - return &Empty{}, err -} - -type GRPCServer struct { - Impl Notifier -} - -func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { - RegisterNotifierServer(s, p.Impl) - return nil -} - -func (p *NotifierPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &GRPCClient{client: NewNotifierClient(c)}, nil -} diff --git a/pkg/setup/README.md b/pkg/setup/README.md index 3585ee8b141..9cdc7243975 100644 --- a/pkg/setup/README.md +++ b/pkg/setup/README.md @@ -129,7 +129,7 @@ services: and must all return true for a service to be detected (implied *and* clause, no short-circuit). A missing or empty `when:` section is evaluated as true. The [expression -engine](https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md) +engine](https://github.com/antonmedv/expr/blob/master/docs/language-definition.md) is the same one used by CrowdSec parser filters. You can force the detection of a process by using the `cscli setup detect... --force-process ` flag. It will always behave as if `` was running. diff --git a/pkg/setup/detect.go b/pkg/setup/detect.go index 7d73092f74e..073b221b10c 100644 --- a/pkg/setup/detect.go +++ b/pkg/setup/detect.go @@ -2,6 +2,7 @@ package setup import ( "bytes" + "errors" "fmt" "io" "os" @@ -9,8 +10,8 @@ import ( "sort" "github.com/Masterminds/semver/v3" - "github.com/antonmedv/expr" "github.com/blackfireio/osinfo" + "github.com/expr-lang/expr" "github.com/shirou/gopsutil/v3/process" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" @@ -53,6 +54,7 @@ func validateDataSource(opaqueDS DataSourceItem) error { // formally validate YAML commonDS := configuration.DataSourceCommonCfg{} + body, err := yaml.Marshal(opaqueDS) if err != nil { return err @@ -66,14 +68,14 @@ func validateDataSource(opaqueDS DataSourceItem) error { // source is mandatory // XXX unless it's not? if commonDS.Source == "" { - return fmt.Errorf("source is empty") + return errors.New("source is empty") } // source must be known - ds := acquisition.GetDataSourceIface(commonDS.Source) - if ds == nil { - return fmt.Errorf("unknown source '%s'", commonDS.Source) + ds, err := acquisition.GetDataSourceIface(commonDS.Source) + if err != nil { + return err } // unmarshal and validate the rest with the specific implementation @@ -104,7 +106,7 @@ func readDetectConfig(fin io.Reader) (DetectConfig, error) { switch dc.Version { case "": - return DetectConfig{}, fmt.Errorf("missing version tag (must be 1.0)") + return DetectConfig{}, errors.New("missing version tag (must be 1.0)") case "1.0": // all is well default: @@ -543,7 +545,7 @@ func Detect(detectReader io.Reader, opts DetectOptions) (Setup, error) { // } // err = yaml.Unmarshal(svc.AcquisYAML, svc.DataSource) // if err != nil { - // return Setup{}, fmt.Errorf("while unmarshaling datasource for service %s: %w", name, err) + // return Setup{}, fmt.Errorf("while parsing datasource for service %s: %w", name, err) // } // } diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index 242ade0494b..588e74dab54 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -94,11 +94,11 @@ func TestPathExists(t *testing.T) { } for _, tc := range tests { - tc := tc env := setup.NewExprEnvironment(setup.DetectOptions{}, setup.ExprOS{}) t.Run(tc.path, func(t *testing.T) { t.Parallel() + actual := env.PathExists(tc.path) require.Equal(t, tc.expected, actual) }) @@ -147,11 +147,11 @@ func TestVersionCheck(t *testing.T) { } for _, tc := range tests { - tc := tc e := setup.ExprOS{RawVersion: tc.version} t.Run(fmt.Sprintf("Check(%s,%s)", tc.version, tc.constraint), func(t *testing.T) { t.Parallel() + actual, err := e.VersionCheck(tc.constraint) cstest.RequireErrorContains(t, err, tc.expectedErr) require.Equal(t, tc.expected, actual) @@ -184,7 +184,6 @@ func TestNormalizeVersion(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.version, func(t *testing.T) { t.Parallel() actual := setup.NormalizeVersion(tc.version) @@ -246,11 +245,12 @@ func TestListSupported(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + f := tempYAML(t, tc.yml) defer os.Remove(f.Name()) + supported, err := setup.ListSupported(&f) cstest.RequireErrorContains(t, err, tc.expectedErr) require.ElementsMatch(t, tc.expected, supported) @@ -329,9 +329,9 @@ func TestApplyRules(t *testing.T) { env := setup.ExprEnvironment{} for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + svc := setup.Service{When: tc.rules} _, actualOk, err := setup.ApplyRules(svc, env) //nolint:typecheck,nolintlint // exported only for tests cstest.RequireErrorContains(t, err, tc.expectedErr) @@ -419,7 +419,6 @@ detect: } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) defer os.Remove(f.Name()) @@ -513,7 +512,6 @@ detect: } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) defer os.Remove(f.Name()) @@ -825,7 +823,6 @@ func TestDetectForcedOS(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) defer os.Remove(f.Name()) @@ -840,7 +837,6 @@ func TestDetectForcedOS(t *testing.T) { func TestDetectDatasourceValidation(t *testing.T) { // It could be a good idea to test UnmarshalConfig() separately in addition // to Configure(), in each datasource. For now, we test these here. - require := require.New(t) setup.ExecCommand = fakeExecCommand @@ -874,7 +870,7 @@ func TestDetectDatasourceValidation(t *testing.T) { datasource: source: wombat`, expected: setup.Setup{Setup: []setup.ServiceSetup{}}, - expectedErr: "invalid datasource for foobar: unknown source 'wombat'", + expectedErr: "invalid datasource for foobar: unknown data source wombat", }, { name: "source is misplaced", config: ` @@ -1011,7 +1007,6 @@ func TestDetectDatasourceValidation(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { f := tempYAML(t, tc.config) defer os.Remove(f.Name()) diff --git a/pkg/setup/install.go b/pkg/setup/install.go index fc922c5d19b..d63a1ee1775 100644 --- a/pkg/setup/install.go +++ b/pkg/setup/install.go @@ -2,6 +2,8 @@ package setup import ( "bytes" + "context" + "errors" "fmt" "os" "path/filepath" @@ -38,14 +40,14 @@ func decodeSetup(input []byte, fancyErrors bool) (Setup, error) { dec2.KnownFields(true) if err := dec2.Decode(&ret); err != nil { - return ret, fmt.Errorf("while unmarshaling setup file: %w", err) + return ret, fmt.Errorf("while parsing setup file: %w", err) } return ret, nil } // InstallHubItems installs the objects recommended in a setup file. -func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error { +func InstallHubItems(ctx context.Context, hub *cwhub.Hub, input []byte, dryRun bool) error { setupEnvelope, err := decodeSetup(input, false) if err != nil { return err @@ -60,79 +62,71 @@ func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error { continue } - if len(install.Collections) > 0 { - for _, collection := range setupItem.Install.Collections { - item := hub.GetItem(cwhub.COLLECTIONS, collection) - if item == nil { - return fmt.Errorf("collection %s not found", collection) - } + for _, collection := range setupItem.Install.Collections { + item := hub.GetItem(cwhub.COLLECTIONS, collection) + if item == nil { + return fmt.Errorf("collection %s not found", collection) + } - if dryRun { - fmt.Println("dry-run: would install collection", collection) + if dryRun { + fmt.Println("dry-run: would install collection", collection) - continue - } + continue + } - if err := item.Install(forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing collection %s: %w", item.Name, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing collection %s: %w", item.Name, err) } } - if len(install.Parsers) > 0 { - for _, parser := range setupItem.Install.Parsers { - if dryRun { - fmt.Println("dry-run: would install parser", parser) + for _, parser := range setupItem.Install.Parsers { + if dryRun { + fmt.Println("dry-run: would install parser", parser) - continue - } + continue + } - item := hub.GetItem(cwhub.PARSERS, parser) - if item == nil { - return fmt.Errorf("parser %s not found", parser) - } + item := hub.GetItem(cwhub.PARSERS, parser) + if item == nil { + return fmt.Errorf("parser %s not found", parser) + } - if err := item.Install(forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing parser %s: %w", item.Name, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing parser %s: %w", item.Name, err) } } - if len(install.Scenarios) > 0 { - for _, scenario := range setupItem.Install.Scenarios { - if dryRun { - fmt.Println("dry-run: would install scenario", scenario) + for _, scenario := range setupItem.Install.Scenarios { + if dryRun { + fmt.Println("dry-run: would install scenario", scenario) - continue - } + continue + } - item := hub.GetItem(cwhub.SCENARIOS, scenario) - if item == nil { - return fmt.Errorf("scenario %s not found", scenario) - } + item := hub.GetItem(cwhub.SCENARIOS, scenario) + if item == nil { + return fmt.Errorf("scenario %s not found", scenario) + } - if err := item.Install(forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing scenario %s: %w", item.Name, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing scenario %s: %w", item.Name, err) } } - if len(install.PostOverflows) > 0 { - for _, postoverflow := range setupItem.Install.PostOverflows { - if dryRun { - fmt.Println("dry-run: would install postoverflow", postoverflow) + for _, postoverflow := range setupItem.Install.PostOverflows { + if dryRun { + fmt.Println("dry-run: would install postoverflow", postoverflow) - continue - } + continue + } - item := hub.GetItem(cwhub.POSTOVERFLOWS, postoverflow) - if item == nil { - return fmt.Errorf("postoverflow %s not found", postoverflow) - } + item := hub.GetItem(cwhub.POSTOVERFLOWS, postoverflow) + if item == nil { + return fmt.Errorf("postoverflow %s not found", postoverflow) + } - if err := item.Install(forceAction, downloadOnly); err != nil { - return fmt.Errorf("while installing postoverflow %s: %w", item.Name, err) - } + if err := item.Install(ctx, forceAction, downloadOnly); err != nil { + return fmt.Errorf("while installing postoverflow %s: %w", item.Name, err) } } } @@ -173,7 +167,7 @@ func marshalAcquisDocuments(ads []AcquisDocument, toDir string) (string, error) if toDir != "" { if ad.AcquisFilename == "" { - return "", fmt.Errorf("empty acquis filename") + return "", errors.New("empty acquis filename") } fname := filepath.Join(toDir, ad.AcquisFilename) diff --git a/pkg/setup/units.go b/pkg/setup/units.go index a0bccba4aac..861513d3f1d 100644 --- a/pkg/setup/units.go +++ b/pkg/setup/units.go @@ -2,6 +2,7 @@ package setup import ( "bufio" + "errors" "fmt" "strings" @@ -34,14 +35,14 @@ func systemdUnitList() ([]string, error) { for scanner.Scan() { line := scanner.Text() - if len(line) == 0 { + if line == "" { break // the rest of the output is footer } if !header { spaceIdx := strings.IndexRune(line, ' ') if spaceIdx == -1 { - return ret, fmt.Errorf("can't parse systemctl output") + return ret, errors.New("can't parse systemctl output") } line = line[:spaceIdx] diff --git a/pkg/types/event.go b/pkg/types/event.go index 074241918d8..e016d0294c4 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -2,11 +2,12 @@ package types import ( "net" + "strings" "time" + "github.com/expr-lang/expr/vm" log "github.com/sirupsen/logrus" - "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -19,11 +20,11 @@ const ( // Event is the structure representing a runtime event (log or overflow) type Event struct { /* is it a log or an overflow */ - Type int `yaml:"Type,omitempty" json:"Type,omitempty"` //Can be types.LOG (0) or types.OVFLOW (1) - ExpectMode int `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` //how to buckets should handle event : types.TIMEMACHINE or types.LIVE + Type int `yaml:"Type,omitempty" json:"Type,omitempty"` // Can be types.LOG (0) or types.OVFLOW (1) + ExpectMode int `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` // how to buckets should handle event : types.TIMEMACHINE or types.LIVE Whitelisted bool `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"` WhitelistReason string `yaml:"WhitelistReason,omitempty" json:"whitelist_reason,omitempty"` - //should add whitelist reason ? + // should add whitelist reason ? /* the current stage of the line being parsed */ Stage string `yaml:"Stage,omitempty" json:"Stage,omitempty"` /* original line (produced by acquisition) */ @@ -36,22 +37,43 @@ type Event struct { Unmarshaled map[string]interface{} `yaml:"Unmarshaled,omitempty" json:"Unmarshaled,omitempty"` /* Overflow */ Overflow RuntimeAlert `yaml:"Overflow,omitempty" json:"Alert,omitempty"` - Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` //parsed time `json:"-"` `` + Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` // parsed time `json:"-"` `` StrTime string `yaml:"StrTime,omitempty" json:"StrTime,omitempty"` StrTimeFormat string `yaml:"StrTimeFormat,omitempty" json:"StrTimeFormat,omitempty"` MarshaledTime string `yaml:"MarshaledTime,omitempty" json:"MarshaledTime,omitempty"` - Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` //can be set to false to avoid processing line + Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` // can be set to false to avoid processing line Appsec AppsecEvent `yaml:"Appsec,omitempty" json:"Appsec,omitempty"` /* Meta is the only part that will make it to the API - it should be normalized */ Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"` } +func (e *Event) SetMeta(key string, value string) bool { + if e.Meta == nil { + e.Meta = make(map[string]string) + } + + e.Meta[key] = value + + return true +} + +func (e *Event) SetParsed(key string, value string) bool { + if e.Parsed == nil { + e.Parsed = make(map[string]string) + } + + e.Parsed[key] = value + + return true +} + func (e *Event) GetType() string { - if e.Type == OVFLW { + switch e.Type { + case OVFLW: return "overflow" - } else if e.Type == LOG { + case LOG: return "log" - } else { + default: log.Warningf("unknown event type for %+v", e) return "unknown" } @@ -73,11 +95,13 @@ func (e *Event) GetMeta(key string) string { } } } + return "" } func (e *Event) ParseIPSources() []net.IP { var srcs []net.IP + switch e.Type { case LOG: if _, ok := e.Meta["source_ip"]; ok { @@ -88,6 +112,7 @@ func (e *Event) ParseIPSources() []net.IP { srcs = append(srcs, net.ParseIP(k)) } } + return srcs } @@ -114,8 +139,8 @@ type RuntimeAlert struct { Whitelisted bool `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"` Reprocess bool `yaml:"Reprocess,omitempty" json:"Reprocess,omitempty"` Sources map[string]models.Source `yaml:"Sources,omitempty" json:"Sources,omitempty"` - Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` //this one is a pointer to APIAlerts[0] for convenience. - //APIAlerts will be populated at the end when there is more than one source + Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` // this one is a pointer to APIAlerts[0] for convenience. + // APIAlerts will be populated at the end when there is more than one source APIAlerts []models.Alert `yaml:"APIAlerts,omitempty" json:"APIAlerts,omitempty"` } @@ -124,5 +149,21 @@ func (r RuntimeAlert) GetSources() []string { for key := range r.Sources { ret = append(ret, key) } + return ret } + +func NormalizeScope(scope string) string { + switch strings.ToLower(scope) { + case "ip": + return Ip + case "range": + return Range + case "as": + return AS + case "country": + return Country + default: + return scope + } +} diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go index 14ca48cd2a8..97b13f96d9a 100644 --- a/pkg/types/event_test.go +++ b/pkg/types/event_test.go @@ -9,6 +9,86 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) +func TestSetParsed(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetParsed: Valid", + evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map", + evt: &Event{Parsed: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetParsed: Existing map+key", + evt: &Event{Parsed: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetParsed(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.Parsed[tt.key]) + }) + } + +} + +func TestSetMeta(t *testing.T) { + tests := []struct { + name string + evt *Event + key string + value string + expected bool + }{ + { + name: "SetMeta: Valid", + evt: &Event{}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map", + evt: &Event{Meta: map[string]string{}}, + key: "test", + value: "test", + expected: true, + }, + { + name: "SetMeta: Existing map+key", + evt: &Event{Meta: map[string]string{"test": "foobar"}}, + key: "test", + value: "test", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.evt.SetMeta(tt.key, tt.value) + assert.Equal(t, tt.value, tt.evt.GetMeta(tt.key)) + }) + } + +} + func TestParseIPSources(t *testing.T) { tests := []struct { name string @@ -70,7 +150,6 @@ func TestParseIPSources(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { ips := tt.evt.ParseIPSources() assert.Equal(t, tt.expected, ips) diff --git a/pkg/types/getfstype.go b/pkg/types/getfstype.go new file mode 100644 index 00000000000..728e986bed0 --- /dev/null +++ b/pkg/types/getfstype.go @@ -0,0 +1,115 @@ +//go:build !windows && !freebsd && !openbsd + +package types + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// Generated with `man statfs | grep _MAGIC | awk '{split(tolower($1),a,"_"); print $2 ": \"" a[1] "\","}'` +// ext2/3/4 duplicates removed to just have ext4 +// XIAFS removed as well +var fsTypeMapping = map[int64]string{ + 0xadf5: "adfs", + 0xadff: "affs", + 0x5346414f: "afs", + 0x09041934: "anon", + 0x0187: "autofs", + 0x62646576: "bdevfs", + 0x42465331: "befs", + 0x1badface: "bfs", + 0x42494e4d: "binfmtfs", + 0xcafe4a11: "bpf", + 0x9123683e: "btrfs", + 0x73727279: "btrfs", + 0x27e0eb: "cgroup", + 0x63677270: "cgroup2", + 0xff534d42: "cifs", + 0x73757245: "coda", + 0x012ff7b7: "coh", + 0x28cd3d45: "cramfs", + 0x64626720: "debugfs", + 0x1373: "devfs", + 0x1cd1: "devpts", + 0xf15f: "ecryptfs", + 0xde5e81e4: "efivarfs", + 0x00414a53: "efs", + 0x137d: "ext", + 0xef51: "ext2", + 0xef53: "ext4", + 0xf2f52010: "f2fs", + 0x65735546: "fuse", + 0xbad1dea: "futexfs", + 0x4244: "hfs", + 0x00c0ffee: "hostfs", + 0xf995e849: "hpfs", + 0x958458f6: "hugetlbfs", + 0x9660: "isofs", + 0x72b6: "jffs2", + 0x3153464a: "jfs", + 0x137f: "minix", + 0x138f: "minix", + 0x2468: "minix2", + 0x2478: "minix2", + 0x4d5a: "minix3", + 0x19800202: "mqueue", + 0x4d44: "msdos", + 0x11307854: "mtd", + 0x564c: "ncp", + 0x6969: "nfs", + 0x3434: "nilfs", + 0x6e736673: "nsfs", + 0x5346544e: "ntfs", + 0x7461636f: "ocfs2", + 0x9fa1: "openprom", + 0x794c7630: "overlayfs", + 0x50495045: "pipefs", + 0x9fa0: "proc", + 0x6165676c: "pstorefs", + 0x002f: "qnx4", + 0x68191122: "qnx6", + 0x858458f6: "ramfs", + 0x52654973: "reiserfs", + 0x7275: "romfs", + 0x73636673: "securityfs", + 0xf97cff8c: "selinux", + 0x43415d53: "smack", + 0x517b: "smb", + 0xfe534d42: "smb2", + 0x534f434b: "sockfs", + 0x73717368: "squashfs", + 0x62656572: "sysfs", + 0x012ff7b6: "sysv2", + 0x012ff7b5: "sysv4", + 0x01021994: "tmpfs", + 0x74726163: "tracefs", + 0x15013346: "udf", + 0x00011954: "ufs", + 0x9fa2: "usbdevice", + 0x01021997: "v9fs", + 0xa501fcf5: "vxfs", + 0xabba1974: "xenfs", + 0x012ff7b4: "xenix", + 0x58465342: "xfs", + 0x2fc12fc1: "zfs", +} + +func GetFSType(path string) (string, error) { + var buf unix.Statfs_t + + err := unix.Statfs(path, &buf) + + if err != nil { + return "", err + } + + fsType, ok := fsTypeMapping[int64(buf.Type)] //nolint:unconvert + + if !ok { + return "", fmt.Errorf("unknown fstype %d", buf.Type) + } + + return fsType, nil +} diff --git a/pkg/types/getfstype_freebsd.go b/pkg/types/getfstype_freebsd.go new file mode 100644 index 00000000000..8fbe3dd7cc4 --- /dev/null +++ b/pkg/types/getfstype_freebsd.go @@ -0,0 +1,25 @@ +//go:build freebsd + +package types + +import ( + "fmt" + "syscall" +) + +func GetFSType(path string) (string, error) { + var fsStat syscall.Statfs_t + + if err := syscall.Statfs(path, &fsStat); err != nil { + return "", fmt.Errorf("failed to get filesystem type: %w", err) + } + + bs := fsStat.Fstypename + + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + + return string(b), nil +} diff --git a/pkg/types/getfstype_openbsd.go b/pkg/types/getfstype_openbsd.go new file mode 100644 index 00000000000..9ec254b7bec --- /dev/null +++ b/pkg/types/getfstype_openbsd.go @@ -0,0 +1,25 @@ +//go:build openbsd + +package types + +import ( + "fmt" + "syscall" +) + +func GetFSType(path string) (string, error) { + var fsStat syscall.Statfs_t + + if err := syscall.Statfs(path, &fsStat); err != nil { + return "", fmt.Errorf("failed to get filesystem type: %w", err) + } + + bs := fsStat.F_fstypename + + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + + return string(b), nil +} diff --git a/pkg/types/getfstype_windows.go b/pkg/types/getfstype_windows.go new file mode 100644 index 00000000000..03d8fffd48d --- /dev/null +++ b/pkg/types/getfstype_windows.go @@ -0,0 +1,53 @@ +package types + +import ( + "path/filepath" + "syscall" + "unsafe" +) + +func GetFSType(path string) (string, error) { + kernel32, err := syscall.LoadLibrary("kernel32.dll") + if err != nil { + return "", err + } + defer syscall.FreeLibrary(kernel32) + + getVolumeInformation, err := syscall.GetProcAddress(kernel32, "GetVolumeInformationW") + if err != nil { + return "", err + } + + // Convert relative path to absolute path + absPath, err := filepath.Abs(path) + if err != nil { + return "", err + } + + // Get the root path of the volume + volumeRoot := filepath.VolumeName(absPath) + "\\" + + volumeRootPtr, _ := syscall.UTF16PtrFromString(volumeRoot) + + var ( + fileSystemNameBuffer = make([]uint16, 260) + nFileSystemNameSize = uint32(len(fileSystemNameBuffer)) + ) + + ret, _, err := syscall.SyscallN(getVolumeInformation, + uintptr(unsafe.Pointer(volumeRootPtr)), + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&fileSystemNameBuffer[0])), + uintptr(nFileSystemNameSize), + 0) + + if ret == 0 { + return "", err + } + + return syscall.UTF16ToString(fileSystemNameBuffer), nil +} diff --git a/pkg/types/ip.go b/pkg/types/ip.go index 5e4d7734f2d..9d08afd8809 100644 --- a/pkg/types/ip.go +++ b/pkg/types/ip.go @@ -2,6 +2,7 @@ package types import ( "encoding/binary" + "errors" "fmt" "math" "net" @@ -15,6 +16,7 @@ func LastAddress(n net.IPNet) net.IP { if ip == nil { // IPv6 ip = n.IP + return net.IP{ ip[0] | ^n.Mask[0], ip[1] | ^n.Mask[1], ip[2] | ^n.Mask[2], ip[3] | ^n.Mask[3], ip[4] | ^n.Mask[4], ip[5] | ^n.Mask[5], @@ -38,12 +40,13 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("while parsing range %s: %w", anyIP, err) } + return Range2Ints(*net) } ip := net.ParseIP(anyIP) if ip == nil { - return -1, 0, 0, 0, 0, fmt.Errorf("invalid address") + return -1, 0, 0, 0, 0, errors.New("invalid address") } sz, start, end, err := IP2Ints(ip) @@ -56,19 +59,22 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) { /*size (16|4), nw_start, suffix_start, nw_end, suffix_end, error*/ func Range2Ints(network net.IPNet) (int, int64, int64, int64, int64, error) { - szStart, nwStart, sfxStart, err := IP2Ints(network.IP) if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("converting first ip in range: %w", err) } + lastAddr := LastAddress(network) + szEnd, nwEnd, sfxEnd, err := IP2Ints(lastAddr) if err != nil { return -1, 0, 0, 0, 0, fmt.Errorf("transforming last address of range: %w", err) } + if szEnd != szStart { return -1, 0, 0, 0, 0, fmt.Errorf("inconsistent size for range first(%d) and last(%d) ip", szStart, szEnd) } + return szStart, nwStart, sfxStart, nwEnd, sfxEnd, nil } @@ -85,6 +91,7 @@ func uint2int(u uint64) int64 { ret = int64(u) ret -= math.MaxInt64 } + return ret } @@ -97,13 +104,15 @@ func IP2Ints(pip net.IP) (int, int64, int64, error) { if pip4 != nil { ip_nw32 := binary.BigEndian.Uint32(pip4) - return 4, uint2int(uint64(ip_nw32)), uint2int(ip_sfx), nil - } else if pip16 != nil { + } + + if pip16 != nil { ip_nw = binary.BigEndian.Uint64(pip16[0:8]) ip_sfx = binary.BigEndian.Uint64(pip16[8:16]) + return 16, uint2int(ip_nw), uint2int(ip_sfx), nil - } else { - return -1, 0, 0, fmt.Errorf("unexpected len %d for %s", len(pip), pip) } + + return -1, 0, 0, fmt.Errorf("unexpected len %d for %s", len(pip), pip) } diff --git a/pkg/types/queue.go b/pkg/types/queue.go index d9b737d548f..12a3ab37074 100644 --- a/pkg/types/queue.go +++ b/pkg/types/queue.go @@ -22,7 +22,7 @@ func NewQueue(l int) *Queue { Queue: make([]Event, 0, l), L: l, } - log.WithFields(log.Fields{"Capacity": q.L}).Debugf("Creating queue") + log.WithField("Capacity", q.L).Debugf("Creating queue") return q } diff --git a/pkg/types/utils.go b/pkg/types/utils.go index e42c36d8aeb..712d44ba12d 100644 --- a/pkg/types/utils.go +++ b/pkg/types/utils.go @@ -3,6 +3,7 @@ package types import ( "fmt" "path/filepath" + "strings" "time" log "github.com/sirupsen/logrus" @@ -67,3 +68,12 @@ func ConfigureLogger(clog *log.Logger) error { func UtcNow() time.Time { return time.Now().UTC() } + +func IsNetworkFS(path string) (bool, string, error) { + fsType, err := GetFSType(path) + if err != nil { + return false, "", err + } + fsType = strings.ToLower(fsType) + return fsType == "nfs" || fsType == "cifs" || fsType == "smb" || fsType == "smb2", fsType, nil +} diff --git a/rpm/SOURCES/crowdsec.unit.patch b/rpm/SOURCES/crowdsec.unit.patch deleted file mode 100644 index af9fe5c31e3..00000000000 --- a/rpm/SOURCES/crowdsec.unit.patch +++ /dev/null @@ -1,13 +0,0 @@ ---- config/crowdsec.service-orig 2022-03-24 09:46:16.581681532 +0000 -+++ config/crowdsec.service 2022-03-24 09:46:28.761681532 +0000 -@@ -5,8 +5,8 @@ - [Service] - Type=notify - Environment=LC_ALL=C LANG=C --ExecStartPre=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error --ExecStart=/usr/local/bin/crowdsec -c /etc/crowdsec/config.yaml -+ExecStartPre=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml -t -error -+ExecStart=/usr/bin/crowdsec -c /etc/crowdsec/config.yaml - #ExecStartPost=/bin/sleep 0.1 - ExecReload=/bin/kill -HUP $MAINPID - Restart=always diff --git a/rpm/SPECS/crowdsec.spec b/rpm/SPECS/crowdsec.spec index f14df932590..ab71b650d11 100644 --- a/rpm/SPECS/crowdsec.spec +++ b/rpm/SPECS/crowdsec.spec @@ -8,8 +8,7 @@ License: MIT URL: https://crowdsec.net Source0: https://github.com/crowdsecurity/%{name}/archive/v%(echo $VERSION).tar.gz Source1: 80-%{name}.preset -Patch0: crowdsec.unit.patch -Patch1: user.patch +Patch0: user.patch BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n) BuildRequires: systemd @@ -32,13 +31,13 @@ Requires: crontabs %setup -q -T -b 0 %patch0 -%patch1 %build sed -i "s#/usr/local/lib/crowdsec/plugins/#%{_libdir}/%{name}/plugins/#g" config/config.yaml %install rm -rf %{buildroot} +mkdir -p %{buildroot}/etc/crowdsec/acquis.d mkdir -p %{buildroot}/etc/crowdsec/hub mkdir -p %{buildroot}/etc/crowdsec/patterns mkdir -p %{buildroot}/etc/crowdsec/console/ @@ -53,7 +52,7 @@ mkdir -p %{buildroot}%{_libdir}/%{name}/plugins/ install -m 755 -D cmd/crowdsec/crowdsec %{buildroot}%{_bindir}/%{name} install -m 755 -D cmd/crowdsec-cli/cscli %{buildroot}%{_bindir}/cscli install -m 755 -D wizard.sh %{buildroot}/usr/share/crowdsec/wizard.sh -install -m 644 -D config/crowdsec.service %{buildroot}%{_unitdir}/%{name}.service +install -m 644 -D debian/crowdsec.service %{buildroot}%{_unitdir}/%{name}.service install -m 644 -D config/patterns/* -t %{buildroot}%{_sysconfdir}/crowdsec/patterns install -m 600 -D config/config.yaml %{buildroot}%{_sysconfdir}/crowdsec install -m 644 -D config/simulation.yaml %{buildroot}%{_sysconfdir}/crowdsec @@ -68,13 +67,14 @@ install -m 551 cmd/notification-http/notification-http %{buildroot}%{_libdir}/%{ install -m 551 cmd/notification-splunk/notification-splunk %{buildroot}%{_libdir}/%{name}/plugins/ install -m 551 cmd/notification-email/notification-email %{buildroot}%{_libdir}/%{name}/plugins/ install -m 551 cmd/notification-sentinel/notification-sentinel %{buildroot}%{_libdir}/%{name}/plugins/ +install -m 551 cmd/notification-file/notification-file %{buildroot}%{_libdir}/%{name}/plugins/ install -m 600 cmd/notification-slack/slack.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ install -m 600 cmd/notification-http/http.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ install -m 600 cmd/notification-splunk/splunk.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ install -m 600 cmd/notification-email/email.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ install -m 600 cmd/notification-sentinel/sentinel.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ - +install -m 600 cmd/notification-file/file.yaml %{buildroot}%{_sysconfdir}/crowdsec/notifications/ %clean rm -rf %{buildroot} @@ -89,6 +89,7 @@ rm -rf %{buildroot} %{_libdir}/%{name}/plugins/notification-splunk %{_libdir}/%{name}/plugins/notification-email %{_libdir}/%{name}/plugins/notification-sentinel +%{_libdir}/%{name}/plugins/notification-file %{_sysconfdir}/%{name}/patterns/linux-syslog %{_sysconfdir}/%{name}/patterns/ruby %{_sysconfdir}/%{name}/patterns/nginx @@ -124,6 +125,7 @@ rm -rf %{buildroot} %config(noreplace) %{_sysconfdir}/%{name}/notifications/splunk.yaml %config(noreplace) %{_sysconfdir}/%{name}/notifications/email.yaml %config(noreplace) %{_sysconfdir}/%{name}/notifications/sentinel.yaml +%config(noreplace) %{_sysconfdir}/%{name}/notifications/file.yaml %config(noreplace) %{_sysconfdir}/cron.daily/%{name} %{_unitdir}/%{name}.service diff --git a/test/README.md b/test/README.md index 723ee5d3e9b..f7b036e7905 100644 --- a/test/README.md +++ b/test/README.md @@ -61,8 +61,6 @@ architectures. - `curl` - `daemonize` - `jq` - - `nc` - - `openssl` - `python3` ## Running all tests @@ -241,6 +239,11 @@ according to the specific needs of the group of tests in the file. crowdsec instance. Crowdsec must not be running while this operation is performed. + - instance-data lock/unlock + +When playing around with a local crowdsec installation, you can run "instance-data lock" +to prevent the bats suite from running, so it won't overwrite your configuration or data. + - `instance-crowdsec [ start | stop ]` Runs (or stops) crowdsec as a background process. PID and lockfiles are @@ -412,10 +415,3 @@ different syntax. Check the heredocs (the < 0 + - name: Replace old text with new text + become: true + ansible.builtin.replace: + path: "{{ item.path }}" + regexp: '#baseurl=http://mirror.centos.org' + replace: 'baseurl=https://vault.centos.org' + loop: "{{ repo_files.files }}" + when: + - ansible_facts.distribution == "CentOS" + - ansible_facts.distribution_major_version == '8' + - repo_files.matched > 0 + - name: "Install required packages" hosts: all vars_files: @@ -17,6 +51,19 @@ - crowdsecurity.testing.re2 - crowdsecurity.testing.bats_requirements +- name: "Install recent python" + hosts: all + vars_files: + - vars/python.yml + tasks: + - name: role "crowdsecurity.testing.python3" + ansible.builtin.include_role: + name: crowdsecurity.testing.python3 + when: + - ansible_facts.distribution in ['CentOS', 'OracleLinux'] + - ansible_facts.distribution_major_version == '8' or ansible_facts.distribution_major_version == '7' + + - name: "Install Postgres" hosts: all become: true diff --git a/test/ansible/requirements.yml b/test/ansible/requirements.yml index a780e827f85..d5a9b80f659 100644 --- a/test/ansible/requirements.yml +++ b/test/ansible/requirements.yml @@ -14,7 +14,7 @@ collections: - name: ansible.posix - name: https://github.com/crowdsecurity/ansible-collection-crowdsecurity.testing.git type: git - version: v0.0.5 + version: v0.0.7 # - name: crowdsecurity.testing # source: ../../../crowdsecurity.testing diff --git a/test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile b/test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile similarity index 84% rename from test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile rename to test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile index 4a3ec307c4f..f2dc70816c9 100644 --- a/test/ansible/vagrant/experimental/opensuse-15.4/Vagrantfile +++ b/test/ansible/vagrant/experimental/opensuse-15.6/Vagrantfile @@ -1,7 +1,8 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - config.vm.box = 'opensuse/Leap-15.4.x86_64' + config.vm.box = 'opensuse/Leap-15.6.x86_64' + config.vm.box_version = "15.6.13.280" config.vm.define 'crowdsec' config.vm.provision 'shell', path: 'bootstrap' diff --git a/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap b/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap new file mode 100644 index 00000000000..a43165d1828 --- /dev/null +++ b/test/ansible/vagrant/experimental/opensuse-15.6/bootstrap @@ -0,0 +1,3 @@ +#!/bin/sh + +zypper install -y kitty-terminfo diff --git a/test/ansible/vagrant/fedora-33/skip b/test/ansible/vagrant/fedora-37/skip old mode 100755 new mode 100644 similarity index 100% rename from test/ansible/vagrant/fedora-33/skip rename to test/ansible/vagrant/fedora-37/skip diff --git a/test/ansible/vagrant/fedora-34/skip b/test/ansible/vagrant/fedora-38/skip old mode 100755 new mode 100644 similarity index 100% rename from test/ansible/vagrant/fedora-34/skip rename to test/ansible/vagrant/fedora-38/skip diff --git a/test/ansible/vagrant/fedora-33/Vagrantfile b/test/ansible/vagrant/fedora-39/Vagrantfile similarity index 69% rename from test/ansible/vagrant/fedora-33/Vagrantfile rename to test/ansible/vagrant/fedora-39/Vagrantfile index df6f06944ae..ec03661fe39 100644 --- a/test/ansible/vagrant/fedora-33/Vagrantfile +++ b/test/ansible/vagrant/fedora-39/Vagrantfile @@ -1,8 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - # config.vm.box = "fedora/33-cloud-base" - config.vm.box = 'generic/fedora33' + config.vm.box = "fedora/39-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-39/skip b/test/ansible/vagrant/fedora-39/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-39/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/fedora-34/Vagrantfile b/test/ansible/vagrant/fedora-40/Vagrantfile similarity index 69% rename from test/ansible/vagrant/fedora-34/Vagrantfile rename to test/ansible/vagrant/fedora-40/Vagrantfile index db2db8d0879..ec03661fe39 100644 --- a/test/ansible/vagrant/fedora-34/Vagrantfile +++ b/test/ansible/vagrant/fedora-40/Vagrantfile @@ -1,8 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - # config.vm.box = "fedora/34-cloud-base" - config.vm.box = 'generic/fedora34' + config.vm.box = "fedora/39-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-40/skip b/test/ansible/vagrant/fedora-40/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-40/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile b/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile new file mode 100644 index 00000000000..52490900fd8 --- /dev/null +++ b/test/ansible/vagrant/ubuntu-24-04-noble/Vagrantfile @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = 'alvistack/ubuntu-24.04' + config.vm.provision "shell", inline: <<-SHELL + SHELL +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vars/python.yml b/test/ansible/vars/python.yml new file mode 100644 index 00000000000..0cafdcc3d4c --- /dev/null +++ b/test/ansible/vars/python.yml @@ -0,0 +1 @@ +python_version: "3.12.3" diff --git a/test/bats.mk b/test/bats.mk index 0cc5deb9b7a..72ac8863f72 100644 --- a/test/bats.mk +++ b/test/bats.mk @@ -38,6 +38,7 @@ define ENV := export TEST_DIR="$(TEST_DIR)" export LOCAL_DIR="$(LOCAL_DIR)" export BIN_DIR="$(BIN_DIR)" +# append .min to the binary names to use the minimal profile export CROWDSEC="$(CROWDSEC)" export CSCLI="$(CSCLI)" export CONFIG_YAML="$(CONFIG_DIR)/config.yaml" @@ -66,15 +67,20 @@ bats-check-requirements: ## Check dependencies for functional tests @$(TEST_DIR)/bin/check-requirements bats-update-tools: ## Install/update tools required for functional tests - # yq v4.40.4 - GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@1c3d55106075bd37df197b4bc03cb4a413fdb903 - # cfssl v1.6.4 - GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssl@b4d0d877cac528f63db39dfb62d5c96cd3a32a0b - GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssljson@b4d0d877cac528f63db39dfb62d5c96cd3a32a0b + # yq v4.44.3 + GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@bbdd97482f2d439126582a59689eb1c855944955 + # cfssl v1.6.5 + GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssl@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda + GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssljson@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda # Build and installs crowdsec in a local directory. Rebuilds if already exists. bats-build: bats-environment ## Build binaries for functional tests @$(MKDIR) $(BIN_DIR) $(LOG_DIR) $(PID_DIR) $(BATS_PLUGIN_DIR) + # minimal profile + @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) BUILD_PROFILE=minimal + @install -m 0755 cmd/crowdsec/crowdsec $(BIN_DIR)/crowdsec.min + @install -m 0755 cmd/crowdsec-cli/cscli $(BIN_DIR)/cscli.min + # default profile @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) @install -m 0755 cmd/crowdsec/crowdsec cmd/crowdsec-cli/cscli $(BIN_DIR)/ @install -m 0755 cmd/notification-*/notification-* $(BATS_PLUGIN_DIR)/ diff --git a/test/bats/00_wait_for.bats b/test/bats/00_wait_for.bats index ffc6802d9bc..94c65033bb4 100644 --- a/test/bats/00_wait_for.bats +++ b/test/bats/00_wait_for.bats @@ -68,4 +68,3 @@ setup() { 2 EOT } - diff --git a/test/bats/01_crowdsec.bats b/test/bats/01_crowdsec.bats index be06ac9261a..aa5830a6bae 100644 --- a/test/bats/01_crowdsec.bats +++ b/test/bats/01_crowdsec.bats @@ -24,8 +24,8 @@ teardown() { #---------- @test "crowdsec (usage)" { - rune -0 wait-for --out "Usage of " "${CROWDSEC}" -h - rune -0 wait-for --out "Usage of " "${CROWDSEC}" --help + rune -0 wait-for --out "Usage of " "$CROWDSEC" -h + rune -0 wait-for --out "Usage of " "$CROWDSEC" --help } @test "crowdsec (unknown flag)" { @@ -33,19 +33,24 @@ teardown() { } @test "crowdsec (unknown argument)" { - rune -0 wait-for --err "argument provided but not defined: trololo" "${CROWDSEC}" trololo + rune -0 wait-for --err "argument provided but not defined: trololo" "$CROWDSEC" trololo +} + +@test "crowdsec -version" { + rune -0 "$CROWDSEC" -version + assert_output --partial "version:" } @test "crowdsec (no api and no agent)" { rune -0 wait-for \ - --err "You must run at least the API Server or crowdsec" \ - "${CROWDSEC}" -no-api -no-cs + --err "you must run at least the API Server or crowdsec" \ + "$CROWDSEC" -no-api -no-cs } @test "crowdsec - print error on exit" { # errors that cause program termination are printed to stderr, not only logs config_set '.db_config.type="meh"' - rune -1 "${CROWDSEC}" + rune -1 "$CROWDSEC" assert_stderr --partial "unable to create database client: unknown database type 'meh'" } @@ -53,32 +58,35 @@ teardown() { config_set '.common={}' rune -0 wait-for \ --err "Starting processing data" \ - "${CROWDSEC}" + "$CROWDSEC" refute_output config_set 'del(.common)' rune -0 wait-for \ --err "Starting processing data" \ - "${CROWDSEC}" + "$CROWDSEC" refute_output } @test "CS_LAPI_SECRET not strong enough" { - CS_LAPI_SECRET=foo rune -1 wait-for "${CROWDSEC}" + CS_LAPI_SECRET=foo rune -1 wait-for "$CROWDSEC" assert_stderr --partial "api server init: unable to run local API: controller init: CS_LAPI_SECRET not strong enough" } @test "crowdsec - reload (change of logfile, disabled agent)" { - logdir1=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + logdir1=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) log_old="${logdir1}/crowdsec.log" config_set ".common.log_dir=\"${logdir1}\"" rune -0 ./instance-crowdsec start-pid PID="$output" + + sleep .5 + assert_file_exists "$log_old" assert_file_contains "$log_old" "Starting processing data" - logdir2=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + logdir2=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) log_new="${logdir2}/crowdsec.log" config_set ".common.log_dir=\"${logdir2}\"" @@ -134,7 +142,7 @@ teardown() { ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') rm -f "$ACQUIS_YAML" - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "acquis.yaml: no such file or directory" } @@ -144,10 +152,10 @@ teardown() { config_set '.crowdsec_service.acquisition_path=""' ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') - rm -f "$ACQUIS_DIR" + rm -rf "$ACQUIS_DIR" config_set '.common.log_media="stdout"' - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" # check warning assert_stderr --partial "no acquisition file found" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" @@ -159,11 +167,11 @@ teardown() { config_set '.crowdsec_service.acquisition_path=""' ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') - rm -f "$ACQUIS_DIR" + rm -rf "$ACQUIS_DIR" config_set '.crowdsec_service.acquisition_dir=""' config_set '.common.log_media="stdout"' - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" # check warning assert_stderr --partial "no acquisition_path or acquisition_dir specified" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" @@ -181,17 +189,52 @@ teardown() { rune -0 wait-for \ --err "Starting processing data" \ - "${CROWDSEC}" + "$CROWDSEC" # now, if foo.yaml is empty instead, there won't be valid datasources. cat /dev/null >"$ACQUIS_DIR"/foo.yaml - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } -@test "crowdsec (disabled datasources)" { +@test "crowdsec (datasource not built)" { + config_set '.common.log_media="stdout"' + + # a datasource cannot run - it's not built in the log processor executable + + ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') + mkdir -p "$ACQUIS_DIR" + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: journalctl + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min + + # auto-detection of journalctl_filter still works + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: whatever + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min +} + +@test "crowdsec (disabled datasource)" { if is_package_testing; then # we can't hide journalctl in package testing # because crowdsec is run from systemd @@ -214,8 +257,8 @@ teardown() { #shellcheck disable=SC2016 rune -0 wait-for \ - --err 'datasource '\''journalctl'\'' is not available: exec: "journalctl": executable file not found in ' \ - env PATH='' "${CROWDSEC}" + --err 'datasource '\''journalctl'\'' is not available: exec: \\"journalctl\\": executable file not found in ' \ + env PATH='' "$CROWDSEC" # if all datasources are disabled, crowdsec should exit @@ -223,7 +266,7 @@ teardown() { rm -f "$ACQUIS_YAML" config_set '.crowdsec_service.acquisition_path=""' - rune -1 wait-for env PATH='' "${CROWDSEC}" + rune -1 wait-for env PATH='' "$CROWDSEC" assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } @@ -234,11 +277,11 @@ teardown() { # if filenames are missing, it won't be able to detect source type config_set "$ACQUIS_YAML" '.source="file"' - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "failed to configure datasource file: no filename or filenames configuration provided" config_set "$ACQUIS_YAML" '.filenames=["file.log"]' config_set "$ACQUIS_YAML" '.meh=3' - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "field meh not found in type fileacquisition.FileConfiguration" } diff --git a/test/bats/01_crowdsec_lapi.bats b/test/bats/01_crowdsec_lapi.bats index 4819d724fea..21e1d7a093e 100644 --- a/test/bats/01_crowdsec_lapi.bats +++ b/test/bats/01_crowdsec_lapi.bats @@ -27,25 +27,24 @@ teardown() { @test "lapi (.api.server.enable=false)" { rune -0 config_set '.api.server.enable=false' - rune -1 "${CROWDSEC}" -no-cs - assert_stderr --partial "You must run at least the API Server or crowdsec" + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "you must run at least the API Server or crowdsec" } @test "lapi (no .api.server.listen_uri)" { - rune -0 config_set 'del(.api.server.listen_uri)' - rune -1 "${CROWDSEC}" -no-cs - assert_stderr --partial "no listen_uri specified" + rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)' + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "no listen_uri or listen_socket specified" } @test "lapi (bad .api.server.listen_uri)" { - rune -0 config_set '.api.server.listen_uri="127.0.0.1:-80"' - rune -1 "${CROWDSEC}" -no-cs - assert_stderr --partial "while starting API server: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port" + rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:-80"' + rune -1 "$CROWDSEC" -no-cs + assert_stderr --partial "local API server stopped with error: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port" } @test "lapi (listen on random port)" { config_set '.common.log_media="stdout"' - rune -0 config_set '.api.server.listen_uri="127.0.0.1:0"' - rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs + rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:0"' + rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "$CROWDSEC" -no-cs } - diff --git a/test/bats/01_cscli.bats b/test/bats/01_cscli.bats index 3a5b4aad04c..264870501a5 100644 --- a/test/bats/01_cscli.bats +++ b/test/bats/01_cscli.bats @@ -40,20 +40,20 @@ teardown() { @test "cscli version" { rune -0 cscli version - assert_stderr --partial "version:" - assert_stderr --partial "Codename:" - assert_stderr --partial "BuildDate:" - assert_stderr --partial "GoVersion:" - assert_stderr --partial "Platform:" - assert_stderr --partial "Constraint_parser:" - assert_stderr --partial "Constraint_scenario:" - assert_stderr --partial "Constraint_api:" - assert_stderr --partial "Constraint_acquis:" + assert_output --partial "version:" + assert_output --partial "Codename:" + assert_output --partial "BuildDate:" + assert_output --partial "GoVersion:" + assert_output --partial "Platform:" + assert_output --partial "Constraint_parser:" + assert_output --partial "Constraint_scenario:" + assert_output --partial "Constraint_api:" + assert_output --partial "Constraint_acquis:" # should work without configuration file - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli version - assert_stderr --partial "version:" + assert_output --partial "version:" } @test "cscli help" { @@ -62,7 +62,7 @@ teardown() { assert_line --regexp ".* help .* Help about any command" # should work without configuration file - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli help assert_line "Available Commands:" } @@ -100,10 +100,14 @@ teardown() { # check that LAPI configuration is loaded (human and json, not shows in raw) + sock=$(config_get '.api.server.listen_socket') + rune -0 cscli config show -o human assert_line --regexp ".*- URL +: http://127.0.0.1:8080/" assert_line --regexp ".*- Login +: githubciXXXXXXXXXXXXXXXXXXXXXXXX([a-zA-Z0-9]{16})?" assert_line --regexp ".*- Credentials File +: .*/local_api_credentials.yaml" + assert_line --regexp ".*- Listen URL +: 127.0.0.1:8080" + assert_line --regexp ".*- Listen Socket +: $sock" rune -0 cscli config show -o json rune -0 jq -c '.API.Client.Credentials | [.url,.login[0:32]]' <(output) @@ -126,9 +130,8 @@ teardown() { EOT } - @test "cscli - required configuration paths" { - config=$(cat "${CONFIG_YAML}") + config=$(cat "$CONFIG_YAML") configdir=$(config_get '.config_paths.config_dir') # required configuration paths with no defaults @@ -136,12 +139,12 @@ teardown() { config_set 'del(.config_paths)' rune -1 cscli hub list assert_stderr --partial 'no configuration paths provided' - echo "$config" > "${CONFIG_YAML}" + echo "$config" > "$CONFIG_YAML" config_set 'del(.config_paths.data_dir)' rune -1 cscli hub list assert_stderr --partial "please provide a data directory with the 'data_dir' directive in the 'config_paths' section" - echo "$config" > "${CONFIG_YAML}" + echo "$config" > "$CONFIG_YAML" # defaults @@ -149,13 +152,13 @@ teardown() { rune -0 cscli hub list rune -0 cscli config show --key Config.ConfigPaths.HubDir assert_output "$configdir/hub" - echo "$config" > "${CONFIG_YAML}" + echo "$config" > "$CONFIG_YAML" config_set 'del(.config_paths.index_path)' rune -0 cscli hub list rune -0 cscli config show --key Config.ConfigPaths.HubIndexFile assert_output "$configdir/hub/.index.json" - echo "$config" > "${CONFIG_YAML}" + echo "$config" > "$CONFIG_YAML" } @test "cscli config show-yaml" { @@ -178,110 +181,34 @@ teardown() { assert_stderr --partial "failed to backup config: while creating /dev/null/blah: mkdir /dev/null/blah: not a directory" # pick a dirpath - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) # succeed the first time - rune -0 cscli config backup "${backupdir}" + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" # don't overwrite an existing backup - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" SIMULATION_YAML="$(config_get '.config_paths.simulation_path')" # restore - rm "${SIMULATION_YAML}" - rune -0 cscli config restore "${backupdir}" - assert_file_exists "${SIMULATION_YAML}" + rm "$SIMULATION_YAML" + rune -0 cscli config restore "$backupdir" + assert_file_exists "$SIMULATION_YAML" # cleanup rm -rf -- "${backupdir:?}" # backup: detect missing files - rm "${SIMULATION_YAML}" - rune -1 cscli config backup "${backupdir}" + rm "$SIMULATION_YAML" + rune -1 cscli config backup "$backupdir" assert_stderr --regexp "failed to backup config: failed copy .* to .*: stat .*: no such file or directory" rm -rf -- "${backupdir:?}" } -@test "cscli lapi status" { - rune -0 ./instance-crowdsec start - rune -0 cscli lapi status - - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial " on http://127.0.0.1:8080/" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" -} - -@test "cscli - missing LAPI credentials file" { - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - rm -f "${LOCAL_API_CREDENTIALS}" - rune -1 cscli lapi status - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" - - rune -1 cscli alerts list - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" - - rune -1 cscli decisions list - assert_stderr --partial "loading api client: while reading yaml file: open ${LOCAL_API_CREDENTIALS}: no such file or directory" -} - -@test "cscli - empty LAPI credentials file" { - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - : > "${LOCAL_API_CREDENTIALS}" - rune -1 cscli lapi status - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" - - rune -1 cscli alerts list - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" - - rune -1 cscli decisions list - assert_stderr --partial "no credentials or URL found in api client configuration '${LOCAL_API_CREDENTIALS}'" -} - -@test "cscli - missing LAPI client settings" { - config_set 'del(.api.client)' - rune -1 cscli lapi status - assert_stderr --partial "loading api client: no API client section in configuration" - - rune -1 cscli alerts list - assert_stderr --partial "loading api client: no API client section in configuration" - - rune -1 cscli decisions list - assert_stderr --partial "loading api client: no API client section in configuration" -} - -@test "cscli - malformed LAPI url" { - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - config_set "${LOCAL_API_CREDENTIALS}" '.url="http://127.0.0.1:-80"' - - rune -1 cscli lapi status -o json - rune -0 jq -r '.msg' <(stderr) - assert_output 'parsing api url: parse "http://127.0.0.1:-80/": invalid port ":-80" after host' -} - -@test "cscli - bad LAPI password" { - rune -0 ./instance-crowdsec start - LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') - config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"' - - rune -1 cscli lapi status -o json - rune -0 jq -r '.msg' <(stderr) - assert_output 'failed to authenticate to Local API (LAPI): API error: incorrect Username or Password' -} - -@test "cscli metrics" { - rune -0 ./instance-crowdsec start - rune -0 cscli lapi status - rune -0 cscli metrics - assert_output --partial "Route" - assert_output --partial '/v1/watchers/login' - assert_output --partial "Local API Metrics:" -} - @test "'cscli completion' with or without configuration file" { rune -0 cscli completion bash assert_output --partial "# bash completion for cscli" @@ -292,7 +219,7 @@ teardown() { rune -0 cscli completion fish assert_output --partial "# fish completion for cscli" - rm "${CONFIG_YAML}" + rm "$CONFIG_YAML" rune -0 cscli completion bash assert_output --partial "# bash completion for cscli" } @@ -336,16 +263,14 @@ teardown() { } @test "cscli doc" { - # generating documentation requires a directory named "doc" - cd "$BATS_TEST_TMPDIR" rune -1 cscli doc refute_output - assert_stderr --regexp 'failed to generate cobra doc: open doc/.*: no such file or directory' + assert_stderr --regexp 'failed to generate cscli documentation: open doc/.*: no such file or directory' mkdir -p doc rune -0 cscli doc - refute_output + assert_output "Documentation generated in ./doc" refute_stderr assert_file_exists "doc/cscli.md" assert_file_not_exist "doc/cscli_setup.md" @@ -355,6 +280,14 @@ teardown() { export CROWDSEC_FEATURE_CSCLI_SETUP="true" rune -0 cscli doc assert_file_exists "doc/cscli_setup.md" + + # specify a target directory + mkdir -p "$BATS_TEST_TMPDIR/doc2" + rune -0 cscli doc --target "$BATS_TEST_TMPDIR/doc2" + assert_output "Documentation generated in $BATS_TEST_TMPDIR/doc2" + refute_stderr + assert_file_exists "$BATS_TEST_TMPDIR/doc2/cscli_setup.md" + } @test "feature.yaml for subcommands" { @@ -367,3 +300,24 @@ teardown() { rune -0 cscli setup assert_output --partial 'cscli setup [command]' } + +@test "cscli config feature-flags" { + # disabled + rune -0 cscli config feature-flags + assert_line '✗ cscli_setup: Enable cscli setup command (service detection)' + + # enabled in feature.yaml + CONFIG_DIR=$(dirname "$CONFIG_YAML") + echo ' - cscli_setup' >> "$CONFIG_DIR"/feature.yaml + rune -0 cscli config feature-flags + assert_line '✓ cscli_setup: Enable cscli setup command (service detection)' + + # enabled in environment + # shellcheck disable=SC2031 + export CROWDSEC_FEATURE_CSCLI_SETUP="true" + rune -0 cscli config feature-flags + assert_line '✓ cscli_setup: Enable cscli setup command (service detection)' + + # there are no retired features + rune -0 cscli config feature-flags --retired +} diff --git a/test/bats/01_cscli_lapi.bats b/test/bats/01_cscli_lapi.bats new file mode 100644 index 00000000000..6e876576a6e --- /dev/null +++ b/test/bats/01_cscli_lapi.bats @@ -0,0 +1,213 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + # don't run crowdsec here, not all tests require a running instance +} + +teardown() { + cd "$TEST_DIR" || exit 1 + ./instance-crowdsec stop +} + +#---------- + +@test "cscli lapi status" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "cscli - missing LAPI credentials file" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + rm -f "$LOCAL_API_CREDENTIALS" + rune -1 cscli lapi status + assert_stderr --partial "loading api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" + + rune -1 cscli alerts list + assert_stderr --partial "loading api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" + + rune -1 cscli decisions list + assert_stderr --partial "loading api client: while reading yaml file: open $LOCAL_API_CREDENTIALS: no such file or directory" +} + +@test "cscli - empty LAPI credentials file" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + : > "$LOCAL_API_CREDENTIALS" + rune -1 cscli lapi status + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" + + rune -1 cscli alerts list + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" + + rune -1 cscli decisions list + assert_stderr --partial "no credentials or URL found in api client configuration '$LOCAL_API_CREDENTIALS'" +} + +@test "cscli - LAPI credentials file can reference env variables" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + URL=$(config_get "$LOCAL_API_CREDENTIALS" '.url') + export URL + LOGIN=$(config_get "$LOCAL_API_CREDENTIALS" '.login') + export LOGIN + PASSWORD=$(config_get "$LOCAL_API_CREDENTIALS" '.password') + export PASSWORD + + # shellcheck disable=SC2016 + echo '{"url":"$URL","login":"$LOGIN","password":"$PASSWORD"}' > "$LOCAL_API_CREDENTIALS".local + + config_set '.crowdsec_service.enable=false' + rune -0 ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --partial "You can successfully interact with Local API (LAPI)" + + rm "$LOCAL_API_CREDENTIALS".local + + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.url="$URL"' + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.login="$LOGIN"' + # shellcheck disable=SC2016 + config_set "$LOCAL_API_CREDENTIALS" '.password="$PASSWORD"' + + rune -0 cscli lapi status + assert_output --partial "You can successfully interact with Local API (LAPI)" + + # but if a variable is not defined, there is no specific error message + unset URL + rune -1 cscli lapi status + # shellcheck disable=SC2016 + assert_stderr --partial 'BaseURL must have a trailing slash' +} + +@test "cscli - missing LAPI client settings" { + config_set 'del(.api.client)' + rune -1 cscli lapi status + assert_stderr --partial "loading api client: no API client section in configuration" + + rune -1 cscli alerts list + assert_stderr --partial "loading api client: no API client section in configuration" + + rune -1 cscli decisions list + assert_stderr --partial "loading api client: no API client section in configuration" +} + +@test "cscli - malformed LAPI url" { + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + config_set "$LOCAL_API_CREDENTIALS" '.url="http://127.0.0.1:-80"' + + rune -1 cscli lapi status -o json + rune -0 jq -r '.msg' <(stderr) + assert_output 'failed to authenticate to Local API (LAPI): parse "http://127.0.0.1:-80/": invalid port ":-80" after host' +} + +@test "cscli - bad LAPI password" { + rune -0 ./instance-crowdsec start + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + config_set "$LOCAL_API_CREDENTIALS" '.password="meh"' + + rune -1 cscli lapi status -o json + rune -0 jq -r '.msg' <(stderr) + assert_output 'failed to authenticate to Local API (LAPI): API error: incorrect Username or Password' +} + +@test "cscli lapi register / machines validate" { + rune -1 cscli lapi register + assert_stderr --partial "connection refused" + + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + + rune -0 ./instance-crowdsec start + rune -0 cscli lapi register + assert_stderr --partial "Successfully registered to Local API" + assert_stderr --partial "Local API credentials written to '$LOCAL_API_CREDENTIALS'" + assert_stderr --partial "Run 'sudo systemctl reload crowdsec' for the new configuration to be effective." + + LOGIN=$(config_get "$LOCAL_API_CREDENTIALS" '.login') + + rune -0 cscli machines inspect "$LOGIN" -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "null" + + rune -0 cscli machines validate "$LOGIN" + + rune -0 cscli machines inspect "$LOGIN" -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --machine" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi register --machine newmachine + rune -0 cscli machines validate newmachine + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --token (ignored)" { + # A token is ignored if the server is not configured with it + rune -1 cscli lapi register --machine newmachine --token meh + assert_stderr --partial "connection refused" + + rune -0 ./instance-crowdsec start + rune -1 cscli lapi register --machine newmachine --token meh + assert_stderr --partial '422 Unprocessable Entity: API error: http code 422, invalid request:' + assert_stderr --partial 'registration_token in body should be at least 32 chars long' + + rune -0 cscli lapi register --machine newmachine --token 12345678901234567890123456789012 + assert_stderr --partial "Successfully registered to Local API" + + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "null" +} + +@test "cscli lapi register --token" { + config_set '.api.server.auto_registration.enabled=true' + config_set '.api.server.auto_registration.token="12345678901234567890123456789012"' + config_set '.api.server.auto_registration.allowed_ranges=["127.0.0.1/32"]' + + rune -0 ./instance-crowdsec start + + rune -1 cscli lapi register --machine malicious --token 123456789012345678901234badtoken + assert_stderr --partial "401 Unauthorized: API error: invalid token for auto registration" + rune -1 cscli machines inspect malicious -o json + assert_stderr --partial "unable to read machine data 'malicious': user 'malicious': user doesn't exist" + + rune -0 cscli lapi register --machine newmachine --token 12345678901234567890123456789012 + assert_stderr --partial "Successfully registered to Local API" + rune -0 cscli machines inspect newmachine -o json + rune -0 jq -r '.isValidated' <(output) + assert_output "true" +} + +@test "cscli lapi register --token (bad source ip)" { + config_set '.api.server.auto_registration.enabled=true' + config_set '.api.server.auto_registration.token="12345678901234567890123456789012"' + config_set '.api.server.auto_registration.allowed_ranges=["127.0.0.2/32"]' + + rune -0 ./instance-crowdsec start + + rune -1 cscli lapi register --machine outofrange --token 12345678901234567890123456789012 + assert_stderr --partial "401 Unauthorized: API error: IP not in allowed range for auto registration" + rune -1 cscli machines inspect outofrange -o json + assert_stderr --partial "unable to read machine data 'outofrange': user 'outofrange': user doesn't exist" +} diff --git a/test/bats/02_nolapi.bats b/test/bats/02_nolapi.bats index f1d810bc166..cefa6d798b4 100644 --- a/test/bats/02_nolapi.bats +++ b/test/bats/02_nolapi.bats @@ -27,12 +27,12 @@ teardown() { config_set '.common.log_media="stdout"' rune -0 wait-for \ --err "CrowdSec Local API listening" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "crowdsec should not run without LAPI (-no-api flag)" { config_set '.common.log_media="stdout"' - rune -1 wait-for "${CROWDSEC}" -no-api + rune -1 wait-for "$CROWDSEC" -no-api } @test "crowdsec should not run without LAPI (no api.server in configuration file)" { @@ -40,7 +40,7 @@ teardown() { config_log_stderr rune -0 wait-for \ --err "crowdsec local API is disabled" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi status shouldn't be ok without api.server" { @@ -68,10 +68,10 @@ teardown() { @test "cscli config backup" { config_disable_lapi - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" rm -rf -- "${backupdir:?}" assert_stderr --partial "failed to backup config" diff --git a/test/bats/03_noagent.bats b/test/bats/03_noagent.bats index e75e375ad1c..6be5101cee2 100644 --- a/test/bats/03_noagent.bats +++ b/test/bats/03_noagent.bats @@ -26,14 +26,14 @@ teardown() { config_set '.common.log_media="stdout"' rune -0 wait-for \ --err "Starting processing data" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "no agent: crowdsec LAPI should run (-no-cs flag)" { config_set '.common.log_media="stdout"' rune -0 wait-for \ --err "CrowdSec Local API listening" \ - "${CROWDSEC}" -no-cs + "$CROWDSEC" -no-cs } @test "no agent: crowdsec LAPI should run (no crowdsec_service in configuration file)" { @@ -41,7 +41,7 @@ teardown() { config_log_stderr rune -0 wait-for \ --err "crowdsec agent is disabled" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "no agent: cscli config show" { @@ -62,10 +62,10 @@ teardown() { @test "no agent: cscli config backup" { config_disable_agent - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" @@ -76,7 +76,7 @@ teardown() { config_disable_agent ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/04_capi.bats b/test/bats/04_capi.bats index d5154c1a0d7..7ba6bfa4428 100644 --- a/test/bats/04_capi.bats +++ b/test/bats/04_capi.bats @@ -46,19 +46,32 @@ setup() { assert_stderr --regexp "no configuration for Central API \(CAPI\) in '$(echo $CONFIG_YAML|sed s#//#/#g)'" } -@test "cscli capi status" { +@test "cscli {capi,papi} status" { ./instance-data load config_enable_capi + + # should not panic with no credentials, but return an error + rune -1 cscli papi status + assert_stderr --partial "the Central API (CAPI) must be configured with 'cscli capi register'" + rune -0 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX rune -1 cscli capi status - assert_stderr --partial "no scenarios installed, abort" + assert_stderr --partial "no scenarios or appsec-rules installed, abort" + + rune -1 cscli papi status + assert_stderr --partial "no PAPI URL in configuration" + + rune -0 cscli console enable console_management + rune -1 cscli papi status + assert_stderr --partial "unable to get PAPI permissions" + assert_stderr --partial "Forbidden for plan" rune -0 cscli scenarios install crowdsecurity/ssh-bf rune -0 cscli capi status - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial " on https://api.crowdsec.net/" - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial " on https://api.crowdsec.net/" + assert_output --partial "You can successfully interact with Central API (CAPI)" } @test "cscli alerts list: receive a community pull when capi is enabled" { @@ -85,7 +98,7 @@ setup() { config_disable_agent ./instance-crowdsec start rune -0 cscli capi status - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "You can successfully interact with Central API (CAPI)" } @test "capi register must be run from lapi" { diff --git a/test/bats/04_nocapi.bats b/test/bats/04_nocapi.bats index 234db182a53..d22a6f0a953 100644 --- a/test/bats/04_nocapi.bats +++ b/test/bats/04_nocapi.bats @@ -27,7 +27,7 @@ teardown() { rune -0 wait-for \ --err "Communication with CrowdSec Central API disabled from args" \ - "${CROWDSEC}" -no-capi + "$CROWDSEC" -no-capi } @test "without capi: crowdsec LAPI should still work" { @@ -35,7 +35,7 @@ teardown() { config_set '.common.log_media="stdout"' rune -0 wait-for \ --err "push and pull to Central API disabled" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "without capi: cscli capi status -> fail" { @@ -53,10 +53,10 @@ teardown() { @test "no agent: cscli config backup" { config_disable_capi - backupdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u) - rune -0 cscli config backup "${backupdir}" + backupdir=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp -u) + rune -0 cscli config backup "$backupdir" assert_stderr --partial "Starting configuration backup" - rune -1 cscli config backup "${backupdir}" + rune -1 cscli config backup "$backupdir" assert_stderr --partial "failed to backup config" assert_stderr --partial "file exists" rm -rf -- "${backupdir:?}" @@ -66,7 +66,7 @@ teardown() { config_disable_capi ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/05_config_yaml_local.bats b/test/bats/05_config_yaml_local.bats index b8b6da117ea..ec7a4201964 100644 --- a/test/bats/05_config_yaml_local.bats +++ b/test/bats/05_config_yaml_local.bats @@ -21,7 +21,7 @@ setup() { load "../lib/setup.sh" ./instance-data load rune -0 config_get '.api.client.credentials_path' - LOCAL_API_CREDENTIALS="${output}" + LOCAL_API_CREDENTIALS="$output" export LOCAL_API_CREDENTIALS } @@ -88,13 +88,13 @@ teardown() { @test "simulation.yaml.local" { rune -0 config_get '.config_paths.simulation_path' refute_output null - SIMULATION="${output}" + SIMULATION="$output" - echo "simulation: off" >"${SIMULATION}" + echo "simulation: off" >"$SIMULATION" rune -0 cscli simulation status -o human assert_stderr --partial "global simulation: disabled" - echo "simulation: on" >"${SIMULATION}" + echo "simulation: on" >"$SIMULATION" rune -0 cscli simulation status -o human assert_stderr --partial "global simulation: enabled" @@ -110,7 +110,7 @@ teardown() { @test "profiles.yaml.local" { rune -0 config_get '.api.server.profiles_path' refute_output null - PROFILES="${output}" + PROFILES="$output" cat <<-EOT >"${PROFILES}.local" name: default_ip_remediation @@ -122,17 +122,17 @@ teardown() { on_success: break EOT - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') - echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" + echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"$ACQUIS_YAML" rune -0 cscli collections install crowdsecurity/sshd rune -0 cscli parsers install crowdsecurity/syslog-logs ./instance-crowdsec start sleep .5 - fake_log >>"${tmpfile}" + fake_log >>"$tmpfile" # this could be simplified, but some systems are slow and we don't want to # wait more than required @@ -141,6 +141,6 @@ teardown() { rune -0 cscli decisions list -o json rune -0 jq --exit-status '.[].decisions[0] | [.value,.type] == ["1.1.1.172","captcha"]' <(output) && break done - rm -f -- "${tmpfile}" - [[ "${status}" -eq 0 ]] || fail "captcha not triggered" + rm -f -- "$tmpfile" + [[ "$status" -eq 0 ]] || fail "captcha not triggered" } diff --git a/test/bats/07_setup.bats b/test/bats/07_setup.bats index 9e3f5533728..f832ac572d2 100644 --- a/test/bats/07_setup.bats +++ b/test/bats/07_setup.bats @@ -819,7 +819,6 @@ update-notifier-motd.timer enabled enabled setup: alsdk al; sdf EOT - assert_output "while unmarshaling setup file: yaml: line 2: could not find expected ':'" + assert_output "while parsing setup file: yaml: line 2: could not find expected ':'" assert_stderr --partial "invalid setup file" } - diff --git a/test/bats/08_metrics.bats b/test/bats/08_metrics.bats index 0275d7fd4a0..e260e667524 100644 --- a/test/bats/08_metrics.bats +++ b/test/bats/08_metrics.bats @@ -23,9 +23,9 @@ teardown() { #---------- @test "cscli metrics (crowdsec not running)" { - rune -1 cscli metrics - # crowdsec is down - assert_stderr --partial 'failed to fetch prometheus metrics: executing GET request for URL \"http://127.0.0.1:6060/metrics\" failed: Get \"http://127.0.0.1:6060/metrics\": dial tcp 127.0.0.1:6060: connect: connection refused' + rune -0 cscli metrics + # crowdsec is down, we won't get an error because some metrics come from the db instead + assert_stderr --partial 'while fetching metrics: executing GET request for URL \"http://127.0.0.1:6060/metrics\" failed: Get \"http://127.0.0.1:6060/metrics\": dial tcp 127.0.0.1:6060: connect: connection refused' } @test "cscli metrics (bad configuration)" { @@ -59,3 +59,45 @@ teardown() { rune -1 cscli metrics assert_stderr --partial "prometheus is not enabled, can't show metrics" } + +@test "cscli metrics" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + rune -0 cscli metrics + assert_output --partial "Route" + assert_output --partial '/v1/watchers/login' + assert_output --partial "Local API Metrics:" + + rune -0 cscli metrics -o json + rune -0 jq 'keys' <(output) + assert_output --partial '"alerts",' + assert_output --partial '"parsers",' +} + +@test "cscli metrics list" { + rune -0 cscli metrics list + assert_output --regexp "Type.*Title.*Description" + + rune -0 cscli metrics list -o json + rune -0 jq -c '.[] | [.type,.title]' <(output) + assert_line '["acquisition","Acquisition Metrics"]' +} + +@test "cscli metrics show" { + rune -0 ./instance-crowdsec start + rune -0 cscli lapi status + + assert_equal "$(cscli metrics)" "$(cscli metrics show)" + + rune -1 cscli metrics show foobar + assert_stderr --partial "unknown metrics type: foobar" + + rune -0 cscli metrics show lapi + assert_output --partial "Local API Metrics:" + assert_output --regexp "Route.*Method.*Hits" + assert_output --regexp "/v1/watchers/login.*POST" + + rune -0 cscli metrics show lapi -o json + rune -0 jq -c '.lapi."/v1/watchers/login" | keys' <(output) + assert_json '["POST"]' +} diff --git a/test/bats/08_metrics_bouncer.bats b/test/bats/08_metrics_bouncer.bats new file mode 100644 index 00000000000..c4dfebbab1d --- /dev/null +++ b/test/bats/08_metrics_bouncer.bats @@ -0,0 +1,527 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + ./instance-data load + ./instance-crowdsec start +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli metrics show bouncers (empty)" { + # this message is given only if we ask explicitly for bouncers + notfound="No bouncer metrics found." + + rune -0 cscli metrics show bouncers + assert_output "$notfound" + + rune -0 cscli metrics list + refute_output "$notfound" +} + +@test "rc usage metrics (empty payload)" { + # a registered bouncer can send metrics for the lapi and console + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing remediation component data"}' +} + +@test "rc usage metrics (bad payload)" { + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + log_processors: [] + EOT + ) + + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + remediation_components.0.utc_startup_timestamp in body is required + EOT + + # validation, like timestamp format + + payload=$(yq -o j '.remediation_components[0].utc_startup_timestamp = "2021-09-01T00:00:00Z"' <<<"$payload") + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 400" + assert_json '{message: "json: cannot unmarshal string into Go struct field AllMetrics.remediation_components of type int64"}' + + payload=$(yq -o j '.remediation_components[0].utc_startup_timestamp = 1707399316' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + refute_output + + payload=$(yq -o j '.remediation_components[0].metrics = [{"meta": {}}]' <<<"$payload") + rune -22 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + remediation_components.0.metrics.0.items in body is required + validation failure list: + remediation_components.0.metrics.0.meta.utc_now_timestamp in body is required + remediation_components.0.metrics.0.meta.window_size_seconds in body is required + EOT +} + +@test "rc usage metrics (good payload)" { + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707399316 + log_processors: [] + EOT + ) + + # bouncers have feature flags too + + payload=$(yq -o j ' + .remediation_components[0].feature_flags = ["huey", "dewey", "louie"] | + .remediation_components[0].os = {"name": "Multics", "version": "MR12.5"} + ' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli bouncer inspect testbouncer -o json + rune -0 yq -o j '[.os,.featureflags]' <(output) + assert_json '["Multics/MR12.5",["huey","dewey","louie"]]' + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707399316, "window_size_seconds":600}, + "items":[ + {"name": "foo", "unit": "pound", "value": 3.1415926}, + {"name": "foo", "unit": "pound", "value": 2.7182818}, + {"name": "foo", "unit": "dogyear", "value": 2.7182818} + ] + } + ] + ' <<<"$payload") + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show bouncers -o json + # aggregation is ok -- we are truncating, not rounding, because the float is mandated by swagger. + # but without labels the origin string is empty + assert_json '{bouncers:{testbouncer:{"": {foo: {dogyear: 2, pound: 5}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +--------+-----------------+ + | Origin | foo | + | | dogyear | pound | + +--------+---------+-------+ + | Total | 2 | 5 | + +--------+---------+-------+ + EOT + + # some more realistic values, at least for the labels + # we don't use the same now_timestamp or the payload will be silently discarded + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707399916, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 500, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "active_decisions", "unit": "ip", "value": 1, "labels": {"ip_type": "ipv6", "origin": "cscli"}}, + {"name": "dropped", "unit": "byte", "value": 3800, "labels": {"ip_type": "ipv4", "origin": "CAPI"}}, + {"name": "dropped", "unit": "byte", "value": 0, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "dropped", "unit": "byte", "value": 1034, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_cruzit_web_attacks"}}, + {"name": "dropped", "unit": "byte", "value": 3847, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "dropped", "unit": "byte", "value": 380, "labels": {"ip_type": "ipv6", "origin": "cscli"}}, + {"name": "dropped", "unit": "packet", "value": 100, "labels": {"ip_type": "ipv4", "origin": "CAPI"}}, + {"name": "dropped", "unit": "packet", "value": 10, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "dropped", "unit": "packet", "value": 23, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_cruzit_web_attacks"}}, + {"name": "dropped", "unit": "packet", "value": 58, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "dropped", "unit": "packet", "value": 0, "labels": {"ip_type": "ipv4", "origin": "lists:anotherlist"}}, + {"name": "dropped", "unit": "byte", "value": 0, "labels": {"ip_type": "ipv4", "origin": "lists:anotherlist"}}, + {"name": "dropped", "unit": "packet", "value": 0, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show bouncers -o json + assert_json '{ + "bouncers": { + "testbouncer": { + "": { + "foo": { + "dogyear": 2, + "pound": 5 + } + }, + "CAPI": { + "dropped": { + "byte": 3800, + "packet": 100 + } + }, + "cscli": { + "active_decisions": { + "ip": 1 + }, + "dropped": { + "byte": 380, + "packet": 10 + } + }, + "lists:firehol_cruzit_web_attacks": { + "dropped": { + "byte": 1034, + "packet": 23 + } + }, + "lists:firehol_voipbl": { + "active_decisions": { + "ip": 500 + }, + "dropped": { + "byte": 3847, + "packet": 58 + }, + }, + "lists:anotherlist": { + "dropped": { + "byte": 0, + "packet": 0 + } + } + } + } + }' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------------+------------------+-------------------+-----------------+ + | Origin | active_decisions | dropped | foo | + | | IPs | bytes | packets | dogyear | pound | + +----------------------------------+------------------+---------+---------+---------+-------+ + | CAPI (community blocklist) | - | 3.80k | 100 | - | - | + | cscli (manual decisions) | 1 | 380 | 10 | - | - | + | lists:anotherlist | - | 0 | 0 | - | - | + | lists:firehol_cruzit_web_attacks | - | 1.03k | 23 | - | - | + | lists:firehol_voipbl | 500 | 3.85k | 58 | - | - | + +----------------------------------+------------------+---------+---------+---------+-------+ + | Total | 501 | 9.06k | 191 | 2 | 5 | + +----------------------------------+------------------+---------+---------+---------+-------+ + EOT + + # active_decisions is actually a gauge: values should not be aggregated, keep only the latest one + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 250, "labels": {"ip_type": "ipv4", "origin": "lists:firehol_voipbl"}}, + {"name": "active_decisions", "unit": "ip", "value": 10, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + rune -0 cscli metrics show bouncers -o json + assert_json '{ + "bouncers": { + "testbouncer": { + "": { + "foo": { + "dogyear": 2, + "pound": 5 + } + }, + "CAPI": { + "dropped": { + "byte": 3800, + "packet": 100 + } + }, + "cscli": { + "active_decisions": { + "ip": 10 + }, + "dropped": { + "byte": 380, + "packet": 10 + } + }, + "lists:firehol_cruzit_web_attacks": { + "dropped": { + "byte": 1034, + "packet": 23 + } + }, + "lists:firehol_voipbl": { + "active_decisions": { + "ip": 250 + }, + "dropped": { + "byte": 3847, + "packet": 58 + }, + }, + "lists:anotherlist": { + "dropped": { + "byte": 0, + "packet": 0 + } + } + } + } + }' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------------+------------------+-------------------+-----------------+ + | Origin | active_decisions | dropped | foo | + | | IPs | bytes | packets | dogyear | pound | + +----------------------------------+------------------+---------+---------+---------+-------+ + | CAPI (community blocklist) | - | 3.80k | 100 | - | - | + | cscli (manual decisions) | 10 | 380 | 10 | - | - | + | lists:anotherlist | - | 0 | 0 | - | - | + | lists:firehol_cruzit_web_attacks | - | 1.03k | 23 | - | - | + | lists:firehol_voipbl | 250 | 3.85k | 58 | - | - | + +----------------------------------+------------------+---------+---------+---------+-------+ + | Total | 260 | 9.06k | 191 | 2 | 5 | + +----------------------------------+------------------+---------+---------+---------+-------+ + EOT +} + +@test "rc usage metrics (unknown metrics)" { + # new metrics are introduced in a new bouncer version, unknown by this version of cscli: some are gauges, some are not + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + log_processors: [] + EOT + ) + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707460000, "window_size_seconds":600}, + "items":[ + {"name": "ima_gauge", "unit": "second", "value": 30, "labels": {"origin": "cscli"}}, + {"name": "notagauge", "unit": "inch", "value": 15, "labels": {"origin": "cscli"}} + ] + }, { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "ima_gauge", "unit": "second", "value": 20, "labels": {"origin": "cscli"}}, + {"name": "notagauge", "unit": "inch", "value": 10, "labels": {"origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers: {testbouncer: {cscli: {ima_gauge: {second: 30}, notagauge: {inch: 25}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-09 03:40:00 +0000 UTC: + +--------------------------+--------+-----------+ + | Origin | ima | notagauge | + | | second | inch | + +--------------------------+--------+-----------+ + | cscli (manual decisions) | 30 | 25 | + +--------------------------+--------+-----------+ + | Total | 30 | 25 | + +--------------------------+--------+-----------+ + EOT +} + +@test "rc usage metrics (ipv4/ipv6)" { + # gauge metrics are not aggregated over time, but they are over ip type + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + log_processors: [] + EOT + ) + + payload=$(yq -o j ' + .remediation_components[0].metrics = [ + { + "meta": {"utc_now_timestamp": 1707460000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 200, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "active_decisions", "unit": "ip", "value": 30, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + }, { + "meta": {"utc_now_timestamp": 1707450000, "window_size_seconds":600}, + "items":[ + {"name": "active_decisions", "unit": "ip", "value": 400, "labels": {"ip_type": "ipv4", "origin": "cscli"}}, + {"name": "active_decisions", "unit": "ip", "value": 50, "labels": {"ip_type": "ipv6", "origin": "cscli"}} + ] + } + ] | + .remediation_components[0].type = "crowdsec-firewall-bouncer" + ' <<<"$payload") + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers: {testbouncer: {cscli: {active_decisions: {ip: 230}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (testbouncer) since 2024-02-09 03:40:00 +0000 UTC: + +--------------------------+------------------+ + | Origin | active_decisions | + | | IPs | + +--------------------------+------------------+ + | cscli (manual decisions) | 230 | + +--------------------------+------------------+ + | Total | 230 | + +--------------------------+------------------+ + EOT +} + +@test "rc usage metrics (multiple bouncers)" { + # multiple bouncers have separate totals and can have different types of metrics and units -> different columns + + API_KEY=$(cscli bouncers add bouncer1 -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707369316 + metrics: + - meta: + utc_now_timestamp: 1707399316 + window_size_seconds: 600 + items: + - name: dropped + unit: byte + value: 1000 + labels: + origin: CAPI + - name: dropped + unit: byte + value: 800 + labels: + origin: lists:somelist + - name: processed + unit: byte + value: 12340 + - name: processed + unit: packet + value: 100 + EOT + ) + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + API_KEY=$(cscli bouncers add bouncer2 -o raw) + export API_KEY + + payload=$(yq -o j <<-EOT + remediation_components: + - version: "v1.0" + utc_startup_timestamp: 1707379316 + metrics: + - meta: + utc_now_timestamp: 1707389316 + window_size_seconds: 600 + items: + - name: dropped + unit: byte + value: 1500 + labels: + origin: lists:somelist + - name: dropped + unit: byte + value: 2000 + labels: + origin: CAPI + - name: dropped + unit: packet + value: 20 + labels: + origin: lists:somelist + EOT + ) + + rune -0 curl-with-key '/v1/usage-metrics' -X POST --data "$payload" + + rune -0 cscli metrics show bouncers -o json + assert_json '{bouncers:{bouncer1:{"":{processed:{byte:12340,packet:100}},CAPI:{dropped:{byte:1000}},"lists:somelist":{dropped:{byte:800}}},bouncer2:{"lists:somelist":{dropped:{byte:1500,packet:20}},CAPI:{dropped:{byte:2000}}}}}' + + rune -0 cscli metrics show bouncers + assert_output - <<-EOT + Bouncer Metrics (bouncer1) since 2024-02-08 13:35:16 +0000 UTC: + +----------------------------+---------+-----------------------+ + | Origin | dropped | processed | + | | bytes | bytes | packets | + +----------------------------+---------+-----------+-----------+ + | CAPI (community blocklist) | 1.00k | - | - | + | lists:somelist | 800 | - | - | + +----------------------------+---------+-----------+-----------+ + | Total | 1.80k | 12.34k | 100 | + +----------------------------+---------+-----------+-----------+ + + Bouncer Metrics (bouncer2) since 2024-02-08 10:48:36 +0000 UTC: + +----------------------------+-------------------+ + | Origin | dropped | + | | bytes | packets | + +----------------------------+---------+---------+ + | CAPI (community blocklist) | 2.00k | - | + | lists:somelist | 1.50k | 20 | + +----------------------------+---------+---------+ + | Total | 3.50k | 20 | + +----------------------------+---------+---------+ + EOT +} diff --git a/test/bats/08_metrics_machines.bats b/test/bats/08_metrics_machines.bats new file mode 100644 index 00000000000..3b73839e753 --- /dev/null +++ b/test/bats/08_metrics_machines.bats @@ -0,0 +1,100 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + ./instance-data load + ./instance-crowdsec start +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "lp usage metrics (empty payload)" { + # a registered log processor can send metrics for the lapi and console + TOKEN=$(lp-get-token) + export TOKEN + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + rune -22 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing log processor data"}' +} + +@test "lp usage metrics (bad payload)" { + TOKEN=$(lp-get-token) + export TOKEN + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: + - version: "v1.0" + EOT + ) + + rune -22 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + assert_stderr --partial "error: 422" + rune -0 jq -r '.message' <(output) + assert_output - <<-EOT + validation failure list: + log_processors.0.utc_startup_timestamp in body is required + log_processors.0.datasources in body is required + log_processors.0.hub_items in body is required + EOT +} + +@test "lp usage metrics (full payload)" { + TOKEN=$(lp-get-token) + export TOKEN + + # base payload without any measurement + + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: + - version: "v1.0" + utc_startup_timestamp: 1707399316 + hub_items: {} + feature_flags: + - marshmallows + os: + name: CentOS + version: "8" + metrics: + - name: logs_parsed + value: 5000 + unit: count + labels: {} + items: [] + meta: + window_size_seconds: 600 + utc_now_timestamp: 1707485349 + console_options: + - share_context + datasources: + syslog: 1 + file: 4 + EOT + ) + + rune -0 curl-with-token '/v1/usage-metrics' -X POST --data "$payload" + refute_output +} diff --git a/test/bats/09_context.bats b/test/bats/09_context.bats index ba295451070..71aabc68d29 100644 --- a/test/bats/09_context.bats +++ b/test/bats/09_context.bats @@ -65,6 +65,11 @@ teardown() { assert_stderr --partial "while checking console_context_path: stat $CONTEXT_YAML: no such file or directory" } +@test "csli lapi context delete" { + rune -1 cscli lapi context delete + assert_stderr --partial "command 'delete' has been removed, please manually edit the context file" +} + @test "context file is bad" { echo "bad yaml" > "$CONTEXT_YAML" rune -1 "$CROWDSEC" -t diff --git a/test/bats/09_socket.bats b/test/bats/09_socket.bats new file mode 100644 index 00000000000..f861d8a40dc --- /dev/null +++ b/test/bats/09_socket.bats @@ -0,0 +1,158 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + sockdir=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp -u) + export sockdir + mkdir -p "$sockdir" + socket="$sockdir/crowdsec_api.sock" + export socket + LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') + export LOCAL_API_CREDENTIALS +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load + config_set ".api.server.listen_socket=strenv(socket)" +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "cscli - connects from existing machine with socket" { + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "crowdsec - listen on both socket and TCP" { + ./instance-crowdsec start + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on http://127.0.0.1:8080/" + assert_output --partial "You can successfully interact with Local API (LAPI)" + + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + rune -0 cscli lapi status + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" +} + +@test "cscli - authenticate new machine with socket" { + # verify that if a listen_uri and a socket are set, the socket is used + # by default when creating a local machine. + + rune -0 cscli machines delete "$(cscli machines list -o json | jq -r '.[].machineId')" + + # this one should be using the socket + rune -0 cscli machines add --auto --force + + using=$(config_get "$LOCAL_API_CREDENTIALS" ".url") + + assert [ "$using" = "$socket" ] + + # disable the agent because it counts as a first authentication + config_disable_agent + ./instance-crowdsec start + + # the machine does not have an IP yet + + rune -0 cscli machines list -o json + rune -0 jq -r '.[].ipAddress' <(output) + assert_output null + + # upon first authentication, it's assigned to localhost + + rune -0 cscli lapi status + + rune -0 cscli machines list -o json + rune -0 jq -r '.[].ipAddress' <(output) + assert_output 127.0.0.1 +} + +bouncer_http() { + URI="$1" + curl -fs -H "X-Api-Key: $API_KEY" "http://localhost:8080$URI" +} + +bouncer_socket() { + URI="$1" + curl -fs -H "X-Api-Key: $API_KEY" --unix-socket "$socket" "http://localhost$URI" +} + +@test "lapi - connects from existing bouncer with socket" { + ./instance-crowdsec start + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + # the bouncer does not have an IP yet + + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[].ip_address' <(output) + assert_output "" + + # upon first authentication, it's assigned to localhost + + rune -0 bouncer_socket '/v1/decisions' + assert_output 'null' + refute_stderr + + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[].ip_address' <(output) + assert_output "127.0.0.1" + + # we can still use TCP of course + + rune -0 bouncer_http '/v1/decisions' + assert_output 'null' + refute_stderr +} + +@test "lapi - listen on socket only" { + config_set "del(.api.server.listen_uri)" + + mkdir -p "$sockdir" + + # agent is not able to connect right now + config_disable_agent + ./instance-crowdsec start + + API_KEY=$(cscli bouncers add testbouncer -o raw) + export API_KEY + + # now we can't + + rune -1 cscli lapi status + assert_stderr --partial "connection refused" + + rune -7 bouncer_http '/v1/decisions' + refute_output + refute_stderr + + # here we can + + config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" + + rune -0 cscli lapi status + + rune -0 bouncer_socket '/v1/decisions' + assert_output 'null' + refute_stderr +} diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats index 3f6167ff6f7..f99913dcee5 100644 --- a/test/bats/10_bouncers.bats +++ b/test/bats/10_bouncers.bats @@ -25,7 +25,13 @@ teardown() { @test "there are 0 bouncers" { rune -0 cscli bouncers list -o json - assert_output "[]" + assert_json '[]' + + rune -0 cscli bouncers list -o human + assert_output --partial "Name" + + rune -0 cscli bouncers list -o raw + assert_output --partial 'name' } @test "we can add one bouncer, and delete it" { @@ -33,7 +39,68 @@ teardown() { assert_output --partial "API key for 'ciTestBouncer':" rune -0 cscli bouncers delete ciTestBouncer rune -0 cscli bouncers list -o json - assert_output '[]' + assert_json '[]' +} + +@test "bouncer api-key auth" { + rune -0 cscli bouncers add ciTestBouncer --key "goodkey" + + # connect with good credentials + rune -0 curl-tcp "/v1/decisions" -sS --fail-with-body -H "X-Api-Key: goodkey" + assert_output null + + # connect with bad credentials + rune -22 curl-tcp "/v1/decisions" -sS --fail-with-body -H "X-Api-Key: badkey" + assert_stderr --partial 'error: 403' + assert_json '{message:"access forbidden"}' + + # connect with no credentials + rune -22 curl-tcp "/v1/decisions" -sS --fail-with-body + assert_stderr --partial 'error: 403' + assert_json '{message:"access forbidden"}' +} + +@test "delete non-existent bouncer" { + # this is a fatal error, which is not consistent with "machines delete" + rune -1 cscli bouncers delete something + assert_stderr --partial "unable to delete bouncer: 'something' does not exist" + rune -0 cscli bouncers delete something --ignore-missing + refute_stderr +} + +@test "bouncers delete has autocompletion" { + rune -0 cscli bouncers add foo1 + rune -0 cscli bouncers add foo2 + rune -0 cscli bouncers add bar + rune -0 cscli bouncers add baz + rune -0 cscli __complete bouncers delete 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' +} + +@test "cscli bouncers list" { + export API_KEY=bouncerkey + rune -0 cscli bouncers add ciTestBouncer --key "$API_KEY" + + rune -0 cscli bouncers list -o json + rune -0 jq -c '.[] | [.ip_address,.last_pull,.name]' <(output) + assert_json '["",null,"ciTestBouncer"]' + rune -0 cscli bouncers list -o raw + assert_line 'name,ip,revoked,last_pull,type,version,auth_type' + assert_line 'ciTestBouncer,,validated,,,,api-key' + rune -0 cscli bouncers list -o human + assert_output --regexp 'ciTestBouncer.*api-key.*' + + # the first connection sets last_pull and ip address + rune -0 curl-with-key '/v1/decisions' + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[] | .ip_address' <(output) + assert_output 127.0.0.1 + rune -0 cscli bouncers list -o json + rune -0 jq -r '.[] | .last_pull' <(output) + refute_output null } @test "we can create a bouncer with a known key" { @@ -68,3 +135,12 @@ teardown() { rune -1 cscli bouncers delete ciTestBouncer rune -1 cscli bouncers delete foobarbaz } + +@test "cscli bouncers prune" { + rune -0 cscli bouncers prune + assert_output 'No bouncers to prune.' + rune -0 cscli bouncers add ciTestBouncer + + rune -0 cscli bouncers prune + assert_output 'No bouncers to prune.' +} diff --git a/test/bats/11_bouncers_tls.bats b/test/bats/11_bouncers_tls.bats index 8fb4579259d..554308ae962 100644 --- a/test/bats/11_bouncers_tls.bats +++ b/test/bats/11_bouncers_tls.bats @@ -3,36 +3,116 @@ set -u +# root: root CA +# inter: intermediate CA +# inter_rev: intermediate CA revoked by root (CRL3) +# leaf: valid client cert +# leaf_rev1: client cert revoked by inter (CRL1) +# leaf_rev2: client cert revoked by inter (CRL2) +# leaf_rev3: client cert (indirectly) revoked by root +# +# CRL1: inter revokes leaf_rev1 +# CRL2: inter revokes leaf_rev2 +# CRL3: root revokes inter_rev +# CRL4: root revokes leaf, but is ignored + setup_file() { load "../lib/setup_file.sh" ./instance-data load - tmpdir="${BATS_FILE_TMPDIR}" + tmpdir="$BATS_FILE_TMPDIR" export tmpdir - CFDIR="${BATS_TEST_DIRNAME}/testdata/cfssl" + CFDIR="$BATS_TEST_DIRNAME/testdata/cfssl" export CFDIR - #gen the CA - cfssl gencert --initca "${CFDIR}/ca.json" 2>/dev/null | cfssljson --bare "${tmpdir}/ca" - #gen an intermediate - cfssl gencert --initca "${CFDIR}/intermediate.json" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - #gen server cert for crowdsec with the intermediate - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=server "${CFDIR}/server.json" 2>/dev/null | cfssljson --bare "${tmpdir}/server" - #gen client cert for the bouncer - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer" - #gen client cert for the bouncer with an invalid OU - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer_invalid.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_bad_ou" - #gen client cert for the bouncer directly signed by the CA, it should be refused by crowdsec as uses the intermediate - cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_invalid" - - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/bouncer.json" 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_revoked" - serial="$(openssl x509 -noout -serial -in "${tmpdir}/bouncer_revoked.pem" | cut -d '=' -f2)" - echo "ibase=16; ${serial}" | bc >"${tmpdir}/serials.txt" - cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" - - cat "${tmpdir}/ca.pem" "${tmpdir}/inter.pem" > "${tmpdir}/bundle.pem" + # Root CA + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_root.json" \ + | cfssljson --bare "$tmpdir/root" + + # Intermediate CAs (valid or revoked) + for cert in "inter" "inter_rev"; do + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_intermediate.json" \ + | cfssljson --bare "$tmpdir/$cert" + + cfssl sign -loglevel 2 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile intermediate_ca "$tmpdir/$cert.csr" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Server cert for crowdsec with the intermediate + cfssl gencert -loglevel 2 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=server "$CFDIR/server.json" \ + | cfssljson --bare "$tmpdir/server" + + # Client certs (valid or revoked) + for cert in "leaf" "leaf_rev1" "leaf_rev2"; do + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Client cert (by revoked inter) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter_rev.pem" -ca-key "$tmpdir/inter_rev-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/leaf_rev3" + + # Bad client cert (invalid OU) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer_invalid.json" \ + | cfssljson --bare "$tmpdir/leaf_bad_ou" + + # Bad client cert (directly signed by the CA, it should be refused by crowdsec as it uses the intermediate) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/bouncer.json" \ + | cfssljson --bare "$tmpdir/leaf_invalid" + + truncate -s 0 "$tmpdir/crl.pem" + + # Revoke certs + { + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev1.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev2.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/inter_rev.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + } >> "$tmpdir/crl.pem" + + cat "$tmpdir/root.pem" "$tmpdir/inter.pem" > "$tmpdir/bundle.pem" config_set ' .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | @@ -65,9 +145,14 @@ teardown() { assert_output "[]" } -@test "simulate one bouncer request with a valid cert" { - rune -0 curl -s --cert "${tmpdir}/bouncer.pem" --key "${tmpdir}/bouncer-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 +@test "simulate a bouncer request with a valid cert" { + rune -0 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 assert_output "null" + refute_stderr rune -0 cscli bouncers list -o json rune -0 jq '. | length' <(output) assert_output '1' @@ -77,21 +162,86 @@ teardown() { rune cscli bouncers delete localhost@127.0.0.1 } -@test "simulate one bouncer request with an invalid cert" { - rune curl -s --cert "${tmpdir}/bouncer_invalid.pem" --key "${tmpdir}/bouncer_invalid-key.pem" --cacert "${tmpdir}/ca-key.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 - rune -0 cscli bouncers list -o json - assert_output "[]" +@test "a bouncer authenticated with TLS can send metrics" { + payload=$(yq -o j <<-EOT + remediation_components: [] + log_processors: [] + EOT + ) + + # with mutual authentication there is no api key, so it's detected as RC if user agent != crowdsec + + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/usage-metrics -X POST --data "$payload" + assert_stderr --partial 'error: 400' + assert_json '{message: "Missing remediation component data"}' + + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf.pem" \ + --key "$tmpdir/leaf-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + --user-agent "crowdsec/someversion" \ + https://localhost:8080/v1/usage-metrics -X POST --data "$payload" + assert_stderr --partial 'error: 401' + assert_json '{code:401, message: "cookie token is empty"}' + + rune cscli bouncers delete localhost@127.0.0.1 } -@test "simulate one bouncer request with an invalid OU" { - rune curl -s --cert "${tmpdir}/bouncer_bad_ou.pem" --key "${tmpdir}/bouncer_bad_ou-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 +@test "simulate a bouncer request with an invalid cert" { + rune -77 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf_invalid.pem" \ + --key "$tmpdir/leaf_invalid-key.pem" \ + --cacert "$tmpdir/root-key.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_stderr --partial 'error setting certificate file' rune -0 cscli bouncers list -o json assert_output "[]" } -@test "simulate one bouncer request with a revoked certificate" { - rune -0 curl -i -s --cert "${tmpdir}/bouncer_revoked.pem" --key "${tmpdir}/bouncer_revoked-key.pem" --cacert "${tmpdir}/bundle.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 - assert_output --partial "access forbidden" +@test "simulate a bouncer request with an invalid OU" { + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/leaf_bad_ou.pem" \ + --key "$tmpdir/leaf_bad_ou-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_json '{message: "access forbidden"}' + assert_stderr --partial 'error: 403' rune -0 cscli bouncers list -o json assert_output "[]" } + +@test "simulate a bouncer request with a revoked certificate" { + # we have two certificates revoked by different CRL blocks + # we connect twice to test the cache too + for cert in "leaf_rev1" "leaf_rev2" "leaf_rev1" "leaf_rev2"; do + truncate_log + rune -22 curl --fail-with-body -sS \ + --cert "$tmpdir/$cert.pem" \ + --key "$tmpdir/$cert-key.pem" \ + --cacert "$tmpdir/bundle.pem" \ + https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_log --partial "certificate revoked by CRL" + assert_json '{message: "access forbidden"}' + assert_stderr --partial "error: 403" + rune -0 cscli bouncers list -o json + assert_output "[]" + done +} + +# vvv this test must be last, or it can break the ones that follow + +@test "allowed_ou can't contain an empty string" { + ./instance-crowdsec stop + config_set ' + .common.log_media="stdout" | + .api.server.tls.bouncers_allowed_ou=["bouncer-ou", ""] + ' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "allowed_ou configuration contains invalid empty string" +} + +# ^^^ this test must be last, or it can break the ones that follow diff --git a/test/bats/13_capi_whitelists.bats b/test/bats/13_capi_whitelists.bats index d05a9d93294..ed7ef2ac560 100644 --- a/test/bats/13_capi_whitelists.bats +++ b/test/bats/13_capi_whitelists.bats @@ -31,7 +31,7 @@ teardown() { @test "capi_whitelists: file missing" { rune -0 wait-for \ --err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: no such file or directory" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi_whitelists: error on open" { @@ -40,11 +40,11 @@ teardown() { if is_package_testing; then rune -0 wait-for \ --err "while parsing capi whitelist file .*: empty file" \ - "${CROWDSEC}" + "$CROWDSEC" else rune -0 wait-for \ --err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: permission denied" \ - "${CROWDSEC}" + "$CROWDSEC" fi } @@ -52,28 +52,28 @@ teardown() { echo > "$CAPI_WHITELISTS_YAML" rune -0 wait-for \ --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': empty file" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi_whitelists: empty lists" { echo '{"ips": [], "cidrs": []}' > "$CAPI_WHITELISTS_YAML" rune -0 wait-for \ --err "Starting processing data" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi_whitelists: bad ip" { echo '{"ips": ["blahblah"], "cidrs": []}' > "$CAPI_WHITELISTS_YAML" rune -0 wait-for \ --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': invalid IP address: blahblah" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi_whitelists: bad cidr" { echo '{"ips": [], "cidrs": ["blahblah"]}' > "$CAPI_WHITELISTS_YAML" rune -0 wait-for \ --err "while parsing capi whitelist file '$CAPI_WHITELISTS_YAML': invalid CIDR address: blahblah" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "capi_whitelists: file with ip and cidr values" { diff --git a/test/bats/20_hub.bats b/test/bats/20_hub.bats index 18e3770bcd0..b8fa1e9efca 100644 --- a/test/bats/20_hub.bats +++ b/test/bats/20_hub.bats @@ -76,7 +76,7 @@ teardown() { assert_stderr --partial "invalid hub item appsec-rules:crowdsecurity/vpatch-laravel-debug-mode: latest version missing from index" rune -1 cscli appsec-rules install crowdsecurity/vpatch-laravel-debug-mode --force - assert_stderr --partial "error while installing 'crowdsecurity/vpatch-laravel-debug-mode': while downloading crowdsecurity/vpatch-laravel-debug-mode: latest hash missing from index" + assert_stderr --partial "error while installing 'crowdsecurity/vpatch-laravel-debug-mode': latest hash missing from index. The index file is invalid, please run 'cscli hub update' and try again" } @test "missing reference in hub index" { @@ -125,13 +125,19 @@ teardown() { assert_stderr --partial "Upgraded 0 contexts" assert_stderr --partial "Upgrading collections" assert_stderr --partial "Upgraded 0 collections" + assert_stderr --partial "Upgrading appsec-configs" + assert_stderr --partial "Upgraded 0 appsec-configs" + assert_stderr --partial "Upgrading appsec-rules" + assert_stderr --partial "Upgraded 0 appsec-rules" + assert_stderr --partial "Upgrading collections" + assert_stderr --partial "Upgraded 0 collections" rune -0 cscli parsers install crowdsecurity/syslog-logs rune -0 cscli hub upgrade assert_stderr --partial "crowdsecurity/syslog-logs: up-to-date" rune -0 cscli hub upgrade --force - assert_stderr --partial "crowdsecurity/syslog-logs: overwrite" + assert_stderr --partial "crowdsecurity/syslog-logs: up-to-date" assert_stderr --partial "crowdsecurity/syslog-logs: updated" assert_stderr --partial "Upgraded 1 parsers" # this is used by the cron script to know if the hub was updated diff --git a/test/bats/20_hub_collections.bats b/test/bats/20_hub_collections.bats index 5e5b43a9e4f..6822339ae40 100644 --- a/test/bats/20_hub_collections.bats +++ b/test/bats/20_hub_collections.bats @@ -177,10 +177,9 @@ teardown() { echo "dirty" >"$CONFIG_DIR/collections/sshd.yaml" rune -1 cscli collections install crowdsecurity/sshd - assert_stderr --partial "error while installing 'crowdsecurity/sshd': while enabling crowdsecurity/sshd: crowdsecurity/sshd is tainted, won't enable unless --force" + assert_stderr --partial "error while installing 'crowdsecurity/sshd': while enabling crowdsecurity/sshd: crowdsecurity/sshd is tainted, won't overwrite unless --force" rune -0 cscli collections install crowdsecurity/sshd --force - assert_stderr --partial "crowdsecurity/sshd: overwrite" assert_stderr --partial "Enabled crowdsecurity/sshd" } diff --git a/test/bats/20_hub_collections_dep.bats b/test/bats/20_hub_collections_dep.bats index c3df948a353..673b812dc0d 100644 --- a/test/bats/20_hub_collections_dep.bats +++ b/test/bats/20_hub_collections_dep.bats @@ -121,6 +121,6 @@ teardown() { rune -1 cscli hub list assert_stderr --partial "circular dependency detected" - rune -1 wait-for "${CROWDSEC}" + rune -1 wait-for "$CROWDSEC" assert_stderr --partial "circular dependency detected" } diff --git a/test/bats/20_hub_items.bats b/test/bats/20_hub_items.bats index 72e09dfa268..4b390c90ed4 100644 --- a/test/bats/20_hub_items.bats +++ b/test/bats/20_hub_items.bats @@ -46,7 +46,7 @@ teardown() { '. * {collections:{"crowdsecurity/sshd":{"versions":{"1.2":{"digest":$DIGEST, "deprecated": false}, "1.10": {"digest":$DIGEST, "deprecated": false}}}}}' \ ) echo "$new_hub" >"$INDEX_PATH" - + rune -0 cscli collections install crowdsecurity/sshd truncate -s 0 "$CONFIG_DIR/collections/sshd.yaml" @@ -78,12 +78,12 @@ teardown() { '. * {collections:{"crowdsecurity/sshd":{"versions":{"1.2.3.4":{"digest":"foo", "deprecated": false}}}}}' \ ) echo "$new_hub" >"$INDEX_PATH" - + rune -0 cscli collections install crowdsecurity/sshd rune -1 cscli collections inspect crowdsecurity/sshd --no-metrics -o json # XXX: we are on the verbose side here... rune -0 jq -r ".msg" <(stderr) - assert_output --regexp "failed to read Hub index: failed to sync items: failed to scan .*: while syncing collections sshd.yaml: 1.2.3.4: Invalid Semantic Version. Run 'sudo cscli hub update' to download the index again" + assert_output --regexp "failed to read Hub index: failed to sync hub items: failed to scan .*: while syncing collections sshd.yaml: 1.2.3.4: Invalid Semantic Version. Run 'sudo cscli hub update' to download the index again" } @test "removing or purging an item already removed by hand" { @@ -176,7 +176,7 @@ teardown() { rune -0 mkdir -p "$CONFIG_DIR/collections" rune -0 ln -s /this/does/not/exist.yaml "$CONFIG_DIR/collections/foobar.yaml" rune -0 cscli hub list - assert_stderr --partial "link target does not exist: $CONFIG_DIR/collections/foobar.yaml -> /this/does/not/exist.yaml" + assert_stderr --partial "Ignoring file $CONFIG_DIR/collections/foobar.yaml: lstat /this/does/not/exist.yaml: no such file or directory" rune -0 cscli hub list -o json rune -0 jq '.collections' <(output) assert_json '[]' @@ -193,3 +193,90 @@ teardown() { rune -0 jq -c '.tainted' <(output) assert_output 'false' } + +@test "don't traverse hidden directories (starting with a dot)" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/bar.yaml" + rune -0 cscli hub list --trace + assert_stderr --partial "skipping hidden directory $CONFIG_DIR/scenarios/.foo" +} + +@test "allow symlink to target inside a hidden directory" { + # k8s config maps use hidden directories and links when mounted + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + + # ignored + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # real file + rune -0 touch "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 + + rune -0 rm "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # link to ignored is not ignored, and the name comes from the link + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["myfoo.yaml"]' +} + +@test "item files can be links to links" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{.foo,.bar} + + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/.bar/hidden.yaml" + + # link to a danling link + rune -0 ln -s "$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: lstat $CONFIG_DIR/scenarios/.foo/hidden.yaml: no such file or directory" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # detect link loops + rune -0 ln -s "$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: too many levels of symbolic links" + + rune -0 rm "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 +} + +@test "item files can be in a subdirectory" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/sub/sub2/sub3" + rune -0 touch "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + # subdir name is now part of the item name + rune -0 cscli scenarios inspect sub/imlocal.yaml -o json + rune -0 jq -e '[.tainted,.local==false,true]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/sub2/sub3/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) +} + +@test "same file name for local items in different subdirectories" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{foo,bar} + rune -0 touch "$CONFIG_DIR/scenarios/foo/local.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/bar/local.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["bar/local.yaml","foo/local.yaml"]' +} diff --git a/test/bats/20_hub_parsers.bats b/test/bats/20_hub_parsers.bats index 71a1f933a92..791b1a2177f 100644 --- a/test/bats/20_hub_parsers.bats +++ b/test/bats/20_hub_parsers.bats @@ -177,10 +177,9 @@ teardown() { echo "dirty" >"$CONFIG_DIR/parsers/s02-enrich/whitelists.yaml" rune -1 cscli parsers install crowdsecurity/whitelists - assert_stderr --partial "error while installing 'crowdsecurity/whitelists': while enabling crowdsecurity/whitelists: crowdsecurity/whitelists is tainted, won't enable unless --force" + assert_stderr --partial "error while installing 'crowdsecurity/whitelists': while enabling crowdsecurity/whitelists: crowdsecurity/whitelists is tainted, won't overwrite unless --force" rune -0 cscli parsers install crowdsecurity/whitelists --force - assert_stderr --partial "crowdsecurity/whitelists: overwrite" assert_stderr --partial "Enabled crowdsecurity/whitelists" } diff --git a/test/bats/20_hub_postoverflows.bats b/test/bats/20_hub_postoverflows.bats index de4b1e8a59e..37337b08caa 100644 --- a/test/bats/20_hub_postoverflows.bats +++ b/test/bats/20_hub_postoverflows.bats @@ -177,10 +177,9 @@ teardown() { echo "dirty" >"$CONFIG_DIR/postoverflows/s00-enrich/rdns.yaml" rune -1 cscli postoverflows install crowdsecurity/rdns - assert_stderr --partial "error while installing 'crowdsecurity/rdns': while enabling crowdsecurity/rdns: crowdsecurity/rdns is tainted, won't enable unless --force" + assert_stderr --partial "error while installing 'crowdsecurity/rdns': while enabling crowdsecurity/rdns: crowdsecurity/rdns is tainted, won't overwrite unless --force" rune -0 cscli postoverflows install crowdsecurity/rdns --force - assert_stderr --partial "crowdsecurity/rdns: overwrite" assert_stderr --partial "Enabled crowdsecurity/rdns" } diff --git a/test/bats/20_hub_scenarios.bats b/test/bats/20_hub_scenarios.bats index 9c441057aa2..3ab3d944c93 100644 --- a/test/bats/20_hub_scenarios.bats +++ b/test/bats/20_hub_scenarios.bats @@ -96,7 +96,7 @@ teardown() { # non-existent rune -1 cscli scenario install foo/bar assert_stderr --partial "can't find 'foo/bar' in scenarios" - + # not installed rune -0 cscli scenarios list crowdsecurity/ssh-bf assert_output --regexp 'crowdsecurity/ssh-bf.*disabled' @@ -178,10 +178,9 @@ teardown() { echo "dirty" >"$CONFIG_DIR/scenarios/ssh-bf.yaml" rune -1 cscli scenarios install crowdsecurity/ssh-bf - assert_stderr --partial "error while installing 'crowdsecurity/ssh-bf': while enabling crowdsecurity/ssh-bf: crowdsecurity/ssh-bf is tainted, won't enable unless --force" + assert_stderr --partial "error while installing 'crowdsecurity/ssh-bf': while enabling crowdsecurity/ssh-bf: crowdsecurity/ssh-bf is tainted, won't overwrite unless --force" rune -0 cscli scenarios install crowdsecurity/ssh-bf --force - assert_stderr --partial "crowdsecurity/ssh-bf: overwrite" assert_stderr --partial "Enabled crowdsecurity/ssh-bf" } diff --git a/test/bats/30_machines.bats b/test/bats/30_machines.bats index c7a72c334b1..d4cce67d0b0 100644 --- a/test/bats/30_machines.bats +++ b/test/bats/30_machines.bats @@ -34,13 +34,18 @@ teardown() { rune -0 jq -r '.msg' <(stderr) assert_output --partial 'already exists: please remove it, use "--force" or specify a different file with "-f"' rune -0 cscli machines add local -a --force - assert_output --partial "Machine 'local' successfully added to the local API." + assert_stderr --partial "Machine 'local' successfully added to the local API." +} + +@test "passwords have a size limit" { + rune -1 cscli machines add local --password "$(printf '%73s' '' | tr ' ' x)" + assert_stderr --partial "password too long (max 72 characters)" } @test "add a new machine and delete it" { rune -0 cscli machines add -a -f /dev/null CiTestMachine -o human - assert_output --partial "Machine 'CiTestMachine' successfully added to the local API" - assert_output --partial "API credentials written to '/dev/null'" + assert_stderr --partial "Machine 'CiTestMachine' successfully added to the local API" + assert_stderr --partial "API credentials written to '/dev/null'" # we now have two machines rune -0 cscli machines list -o json @@ -57,6 +62,38 @@ teardown() { assert_output 1 } +@test "delete non-existent machine" { + # this is not a fatal error, won't halt a script with -e + rune -0 cscli machines delete something + assert_stderr --partial "unable to delete machine: 'something' does not exist" + rune -0 cscli machines delete something --ignore-missing + refute_stderr +} + +@test "machines [delete|inspect] has autocompletion" { + rune -0 cscli machines add -a -f /dev/null foo1 + rune -0 cscli machines add -a -f /dev/null foo2 + rune -0 cscli machines add -a -f /dev/null bar + rune -0 cscli machines add -a -f /dev/null baz + rune -0 cscli __complete machines delete 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' + rune -0 cscli __complete machines inspect 'foo' + assert_line --index 0 'foo1' + assert_line --index 1 'foo2' + refute_line 'bar' + refute_line 'baz' +} + +@test "heartbeat is initially null" { + rune -0 cscli machines add foo --auto --file /dev/null + rune -0 cscli machines list -o json + rune -0 yq '.[] | select(.machineId == "foo") | .last_heartbeat' <(output) + assert_output null +} + @test "register, validate and then remove a machine" { rune -0 cscli lapi register --machine CiTestMachineRegister -f /dev/null -o human assert_stderr --partial "Successfully registered to Local API (LAPI)" @@ -85,3 +122,20 @@ teardown() { rune -0 jq '. | length' <(output) assert_output 1 } + +@test "cscli machines prune" { + rune -0 cscli metrics + + # if the fixture has been created some time ago, + # the machines may be old enough to trigger a user prompt. + # make sure the prune duration is high enough. + rune -0 cscli machines prune --duration 1000000h + assert_output 'No machines to prune.' + + rune -0 cscli machines list -o json + rune -0 jq -r '.[-1].machineId' <(output) + rune -0 cscli machines delete "$output" + + rune -0 cscli machines prune + assert_output 'No machines to prune.' +} diff --git a/test/bats/30_machines_tls.bats b/test/bats/30_machines_tls.bats index 535435336ba..ef02d1b57c3 100644 --- a/test/bats/30_machines_tls.bats +++ b/test/bats/30_machines_tls.bats @@ -3,39 +3,119 @@ set -u +# root: root CA +# inter: intermediate CA +# inter_rev: intermediate CA revoked by root (CRL3) +# leaf: valid client cert +# leaf_rev1: client cert revoked by inter (CRL1) +# leaf_rev2: client cert revoked by inter (CRL2) +# leaf_rev3: client cert (indirectly) revoked by root +# +# CRL1: inter revokes leaf_rev1 +# CRL2: inter revokes leaf_rev2 +# CRL3: root revokes inter_rev +# CRL4: root revokes leaf, but is ignored + setup_file() { load "../lib/setup_file.sh" ./instance-data load - CONFIG_DIR=$(dirname "${CONFIG_YAML}") + CONFIG_DIR=$(dirname "$CONFIG_YAML") export CONFIG_DIR - tmpdir="${BATS_FILE_TMPDIR}" + tmpdir="$BATS_FILE_TMPDIR" export tmpdir - CFDIR="${BATS_TEST_DIRNAME}/testdata/cfssl" + CFDIR="$BATS_TEST_DIRNAME/testdata/cfssl" export CFDIR - #gen the CA - cfssl gencert --initca "${CFDIR}/ca.json" 2>/dev/null | cfssljson --bare "${tmpdir}/ca" - #gen an intermediate - cfssl gencert --initca "${CFDIR}/intermediate.json" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" - #gen server cert for crowdsec with the intermediate - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=server "${CFDIR}/server.json" 2>/dev/null | cfssljson --bare "${tmpdir}/server" - #gen client cert for the agent - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent" - #gen client cert for the agent with an invalid OU - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent_invalid.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_bad_ou" - #gen client cert for the agent directly signed by the CA, it should be refused by crowdsec as uses the intermediate - cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_invalid" - - cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config "${CFDIR}/profiles.json" -profile=client "${CFDIR}/agent.json" 2>/dev/null | cfssljson --bare "${tmpdir}/agent_revoked" - serial="$(openssl x509 -noout -serial -in "${tmpdir}/agent_revoked.pem" | cut -d '=' -f2)" - echo "ibase=16; ${serial}" | bc >"${tmpdir}/serials.txt" - cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" - - cat "${tmpdir}/ca.pem" "${tmpdir}/inter.pem" > "${tmpdir}/bundle.pem" + # Root CA + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_root.json" \ + | cfssljson --bare "$tmpdir/root" + + # Intermediate CAs (valid or revoked) + for cert in "inter" "inter_rev"; do + cfssl gencert -loglevel 2 \ + --initca "$CFDIR/ca_intermediate.json" \ + | cfssljson --bare "$tmpdir/$cert" + + cfssl sign -loglevel 2 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile intermediate_ca "$tmpdir/$cert.csr" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Server cert for crowdsec with the intermediate + cfssl gencert -loglevel 2 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=server "$CFDIR/server.json" \ + | cfssljson --bare "$tmpdir/server" + + # Client certs (valid or revoked) + for cert in "leaf" "leaf_rev1" "leaf_rev2"; do + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/$cert" + done + + # Client cert (by revoked inter) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter_rev.pem" -ca-key "$tmpdir/inter_rev-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/leaf_rev3" + + # Bad client cert (invalid OU) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/inter.pem" -ca-key "$tmpdir/inter-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent_invalid.json" \ + | cfssljson --bare "$tmpdir/leaf_bad_ou" + + # Bad client cert (directly signed by the CA, it should be refused by crowdsec as it uses the intermediate) + cfssl gencert -loglevel 3 \ + -ca "$tmpdir/root.pem" -ca-key "$tmpdir/root-key.pem" \ + -config "$CFDIR/profiles.json" -profile=client \ + "$CFDIR/agent.json" \ + | cfssljson --bare "$tmpdir/leaf_invalid" + + truncate -s 0 "$tmpdir/crl.pem" + + # Revoke certs + { + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev1.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf_rev2.pem") \ + "$tmpdir/inter.pem" \ + "$tmpdir/inter-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/inter_rev.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + + echo '-----BEGIN X509 CRL-----' + cfssl gencrl \ + <(cert_serial_number "$tmpdir/leaf.pem") \ + "$tmpdir/root.pem" \ + "$tmpdir/root-key.pem" + echo '-----END X509 CRL-----' + } >> "$tmpdir/crl.pem" + + cat "$tmpdir/root.pem" "$tmpdir/inter.pem" > "$tmpdir/bundle.pem" config_set ' .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | @@ -48,7 +128,7 @@ setup_file() { # remove all machines for machine in $(cscli machines list -o json | jq -r '.[].machineId'); do - cscli machines delete "${machine}" >/dev/null 2>&1 + cscli machines delete "$machine" >/dev/null 2>&1 done config_disable_agent @@ -80,7 +160,7 @@ teardown() { rune -0 wait-for \ --err "missing TLS key file" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "missing cert_file" { @@ -88,64 +168,131 @@ teardown() { rune -0 wait-for \ --err "missing TLS cert file" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "invalid OU for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_bad_ou-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_bad_ou.pem" | + .key_path=strenv(tmpdir) + "/leaf_bad_ou-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf_bad_ou.pem" | .url="https://127.0.0.1:8080" ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start rune -0 cscli machines list -o json assert_output '[]' } @test "we have exactly one machine registered with TLS" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent-key.pem" | - .cert_path=strenv(tmpdir) + "/agent.pem" | + .key_path=strenv(tmpdir) + "/leaf-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf.pem" | .url="https://127.0.0.1:8080" ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start rune -0 cscli lapi status + # second connection, test the tls cache + rune -0 cscli lapi status rune -0 cscli machines list -o json rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]' - cscli machines delete localhost@127.0.0.1 + rune -0 cscli machines delete localhost@127.0.0.1 +} + +@test "a machine can still connect with a unix socket, no TLS" { + sock=$(config_get '.api.server.listen_socket') + export sock + + # an agent is a machine too + config_disable_agent + ./instance-crowdsec start + + rune -0 cscli machines add with-socket --auto --force + rune -0 cscli lapi status + + rune -0 cscli machines list -o json + rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) + assert_output '[1,"with-socket",true,"127.0.0.1","password"]' + + # TLS cannot be used with a unix socket + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + .ca_cert_path=strenv(tmpdir) + "/bundle.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + del(.ca_cert_path) | + .key_path=strenv(tmpdir) + "/leaf-key.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + del(.key_path) | + .cert_path=strenv(tmpdir) + "/leaf.pem" + ' + + rune -1 cscli lapi status + assert_stderr --partial "loading api client: cannot use TLS with a unix socket" + + rune -0 cscli machines delete with-socket } @test "invalid cert for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' + config_set "$CONFIG_DIR/local_api_credentials.yaml" ' .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_invalid-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_invalid.pem" | + .key_path=strenv(tmpdir) + "/leaf_invalid-key.pem" | + .cert_path=strenv(tmpdir) + "/leaf_invalid.pem" | .url="https://127.0.0.1:8080" ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' ./instance-crowdsec start + rune -1 cscli lapi status rune -0 cscli machines list -o json assert_output '[]' } @test "revoked cert for agent" { - config_set "${CONFIG_DIR}/local_api_credentials.yaml" ' - .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | - .key_path=strenv(tmpdir) + "/agent_revoked-key.pem" | - .cert_path=strenv(tmpdir) + "/agent_revoked.pem" | - .url="https://127.0.0.1:8080" - ' + # we have two certificates revoked by different CRL blocks + # we connect twice to test the cache too + for cert in "leaf_rev1" "leaf_rev2" "leaf_rev1" "leaf_rev2"; do + truncate_log + cert="$cert" config_set "$CONFIG_DIR/local_api_credentials.yaml" ' + .ca_cert_path=strenv(tmpdir) + "/bundle.pem" | + .key_path=strenv(tmpdir) + "/" + strenv(cert) + "-key.pem" | + .cert_path=strenv(tmpdir) + "/" + strenv(cert) + ".pem" | + .url="https://127.0.0.1:8080" + ' - config_set "${CONFIG_DIR}/local_api_credentials.yaml" 'del(.login,.password)' - ./instance-crowdsec start - rune -0 cscli machines list -o json - assert_output '[]' + config_set "$CONFIG_DIR/local_api_credentials.yaml" 'del(.login,.password)' + ./instance-crowdsec start + rune -1 cscli lapi status + assert_log --partial "certificate revoked by CRL" + rune -0 cscli machines list -o json + assert_output '[]' + ./instance-crowdsec stop + done } + +# vvv this test must be last, or it can break the ones that follow + +@test "allowed_ou can't contain an empty string" { + config_set ' + .common.log_media="stdout" | + .api.server.tls.agents_allowed_ou=["agent-ou", ""] + ' + rune -1 wait-for "$CROWDSEC" + assert_stderr --partial "allowed_ou configuration contains invalid empty string" +} + +# ^^^ this test must be last, or it can break the ones that follow diff --git a/test/bats/40_cold-logs.bats b/test/bats/40_cold-logs.bats index 36220375b87..070a9eac5f1 100644 --- a/test/bats/40_cold-logs.bats +++ b/test/bats/40_cold-logs.bats @@ -14,9 +14,9 @@ setup_file() { # we reset config and data, and only run the daemon once for all the tests in this file ./instance-data load - cscli collections install crowdsecurity/sshd --error - cscli parsers install crowdsecurity/syslog-logs --error - cscli parsers install crowdsecurity/dateparse-enrich --error + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error >/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null ./instance-crowdsec start } @@ -32,14 +32,14 @@ setup() { #---------- @test "-type and -dsn are required together" { - rune -1 "${CROWDSEC}" -no-api -type syslog + rune -1 "$CROWDSEC" -no-api -type syslog assert_stderr --partial "-type requires a -dsn argument" - rune -1 "${CROWDSEC}" -no-api -dsn file:///dev/fd/0 + rune -1 "$CROWDSEC" -no-api -dsn file:///dev/fd/0 assert_stderr --partial "-dsn requires a -type argument" } @test "the one-shot mode works" { - rune -0 "${CROWDSEC}" -dsn file://<(fake_log) -type syslog -no-api + rune -0 "$CROWDSEC" -dsn file://<(fake_log) -type syslog -no-api refute_output assert_stderr --partial "single file mode : log_media=stdout daemonize=false" assert_stderr --regexp "Adding file .* to filelist" diff --git a/test/bats/40_live-ban.bats b/test/bats/40_live-ban.bats index c6b8ddf1563..fb5fd1fd435 100644 --- a/test/bats/40_live-ban.bats +++ b/test/bats/40_live-ban.bats @@ -14,10 +14,9 @@ setup_file() { # we reset config and data, but run the daemon only in the tests that need it ./instance-data load - cscli collections install crowdsecurity/sshd --error - cscli parsers install crowdsecurity/syslog-logs --error - cscli parsers install crowdsecurity/dateparse-enrich --error - + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error >/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null } teardown_file() { @@ -35,16 +34,29 @@ teardown() { #---------- @test "1.1.1.172 has been banned" { - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') - echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" + echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"$ACQUIS_YAML" ./instance-crowdsec start - fake_log >>"${tmpfile}" - sleep 2 - rm -f -- "${tmpfile}" - rune -0 cscli decisions list -o json - rune -0 jq -r '.[].decisions[0].value' <(output) - assert_output '1.1.1.172' + + sleep 0.2 + + fake_log >>"$tmpfile" + + sleep 0.2 + + rm -f -- "$tmpfile" + + found=0 + # this may take some time in CI + for _ in $(seq 1 10); do + if cscli decisions list -o json | jq -r '.[].decisions[0].value' | grep -q '1.1.1.172'; then + found=1 + break + fi + sleep 0.2 + done + assert_equal 1 "$found" } diff --git a/test/bats/50_simulation.bats b/test/bats/50_simulation.bats index 0d29d6bfd52..bffa50cbccc 100644 --- a/test/bats/50_simulation.bats +++ b/test/bats/50_simulation.bats @@ -13,9 +13,9 @@ setup_file() { load "../lib/setup_file.sh" ./instance-data load - cscli collections install crowdsecurity/sshd --error - cscli parsers install crowdsecurity/syslog-logs --error - cscli parsers install crowdsecurity/dateparse-enrich --error + cscli collections install crowdsecurity/sshd --error >/dev/null + cscli parsers install crowdsecurity/syslog-logs --error >/dev/null + cscli parsers install crowdsecurity/dateparse-enrich --error >/dev/null ./instance-crowdsec start } @@ -33,7 +33,7 @@ setup() { @test "we have one decision" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq '. | length' <(output) assert_output 1 @@ -41,7 +41,7 @@ setup() { @test "1.1.1.174 has been banned (exact)" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq -r '.[].decisions[0].value' <(output) assert_output '1.1.1.174' @@ -49,7 +49,7 @@ setup() { @test "decision has simulated == false (exact)" { rune -0 cscli simulation disable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list -o json rune -0 jq '.[].decisions[0].simulated' <(output) assert_output 'false' @@ -57,7 +57,20 @@ setup() { @test "simulated scenario, listing non-simulated: expect no decision" { rune -0 cscli simulation enable crowdsecurity/ssh-bf - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api + rune -0 cscli decisions list --no-simu -o json + assert_json '[]' +} + +@test "simulated local scenario: expect no decision" { + CONFIG_DIR=$(dirname "$CONFIG_YAML") + HUB_DIR=$(config_get '.config_paths.hub_dir') + rune -0 mkdir -p "$CONFIG_DIR"/scenarios + # replace an installed scenario with a local version + rune -0 cp -r "$HUB_DIR"/scenarios/crowdsecurity/ssh-bf.yaml "$CONFIG_DIR"/scenarios/ssh-bf2.yaml + rune -0 cscli scenarios remove crowdsecurity/ssh-bf --force --purge + rune -0 cscli simulation enable crowdsecurity/ssh-bf + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list --no-simu -o json assert_json '[]' } @@ -65,7 +78,7 @@ setup() { @test "global simulation, listing non-simulated: expect no decision" { rune -0 cscli simulation disable crowdsecurity/ssh-bf rune -0 cscli simulation enable --global - fake_log | "${CROWDSEC}" -dsn file:///dev/fd/0 -type syslog -no-api + fake_log | "$CROWDSEC" -dsn file:///dev/fd/0 -type syslog -no-api rune -0 cscli decisions list --no-simu -o json assert_json '[]' } diff --git a/test/bats/70_plugin_http.bats b/test/bats/70_plugin_http.bats index a8b860aab83..462fc7c9406 100644 --- a/test/bats/70_plugin_http.bats +++ b/test/bats/70_plugin_http.bats @@ -15,7 +15,7 @@ setup_file() { export MOCK_URL PLUGIN_DIR=$(config_get '.config_paths.plugin_dir') # could have a trailing slash - PLUGIN_DIR=$(realpath "${PLUGIN_DIR}") + PLUGIN_DIR=$(realpath "$PLUGIN_DIR") export PLUGIN_DIR # https://mikefarah.gitbook.io/yq/operators/env-variable-operators @@ -35,10 +35,10 @@ setup_file() { .plugin_config.group="" ' - rm -f -- "${MOCK_OUT}" + rm -f -- "$MOCK_OUT" ./instance-crowdsec start - ./instance-mock-http start "${MOCK_PORT}" + ./instance-mock-http start "$MOCK_PORT" } teardown_file() { @@ -63,24 +63,24 @@ setup() { } @test "expected 1 log line from http server" { - rune -0 wc -l <"${MOCK_OUT}" + rune -0 wc -l <"$MOCK_OUT" # wc can pad with spaces on some platforms rune -0 tr -d ' ' < <(output) assert_output 1 } @test "expected to receive 2 alerts in the request body from plugin" { - rune -0 jq -r '.request_body' <"${MOCK_OUT}" + rune -0 jq -r '.request_body' <"$MOCK_OUT" rune -0 jq -r 'length' <(output) assert_output 2 } @test "expected to receive IP 1.2.3.4 as value of first decision" { - rune -0 jq -r '.request_body[0].decisions[0].value' <"${MOCK_OUT}" + rune -0 jq -r '.request_body[0].decisions[0].value' <"$MOCK_OUT" assert_output 1.2.3.4 } @test "expected to receive IP 1.2.3.5 as value of second decision" { - rune -0 jq -r '.request_body[1].decisions[0].value' <"${MOCK_OUT}" + rune -0 jq -r '.request_body[1].decisions[0].value' <"$MOCK_OUT" assert_output 1.2.3.5 } diff --git a/test/bats/71_plugin_dummy.bats b/test/bats/71_plugin_dummy.bats index 95b64fea070..c242d7ec4bc 100644 --- a/test/bats/71_plugin_dummy.bats +++ b/test/bats/71_plugin_dummy.bats @@ -9,15 +9,15 @@ setup_file() { ./instance-data load - tempfile=$(TMPDIR="${BATS_FILE_TMPDIR}" mktemp) + tempfile=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) export tempfile - tempfile2=$(TMPDIR="${BATS_FILE_TMPDIR}" mktemp) + tempfile2=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) export tempfile2 DUMMY_YAML="$(config_get '.config_paths.notification_dir')/dummy.yaml" - config_set "${DUMMY_YAML}" ' + config_set "$DUMMY_YAML" ' .group_wait="5s" | .group_threshold=2 | .output_file=strenv(tempfile) | @@ -67,12 +67,12 @@ setup() { } @test "expected 1 notification" { - rune -0 cat "${tempfile}" + rune -0 cat "$tempfile" assert_output --partial 1.2.3.4 assert_output --partial 1.2.3.5 } @test "second notification works too" { - rune -0 cat "${tempfile2}" + rune -0 cat "$tempfile2" assert_output --partial secondfile } diff --git a/test/bats/72_plugin_badconfig.bats b/test/bats/72_plugin_badconfig.bats index c9a69b9fcb0..7be16c6cf8e 100644 --- a/test/bats/72_plugin_badconfig.bats +++ b/test/bats/72_plugin_badconfig.bats @@ -8,7 +8,7 @@ setup_file() { PLUGIN_DIR=$(config_get '.config_paths.plugin_dir') # could have a trailing slash - PLUGIN_DIR=$(realpath "${PLUGIN_DIR}") + PLUGIN_DIR=$(realpath "$PLUGIN_DIR") export PLUGIN_DIR PROFILES_PATH=$(config_get '.api.server.profiles_path') @@ -26,50 +26,50 @@ setup() { teardown() { ./instance-crowdsec stop - rm -f "${PLUGIN_DIR}"/badname - chmod go-w "${PLUGIN_DIR}"/notification-http || true + rm -f "$PLUGIN_DIR"/badname + chmod go-w "$PLUGIN_DIR"/notification-http || true } #---------- @test "misconfigured plugin, only user is empty" { config_set '.plugin_config.user="" | .plugin_config.group="nogroup"' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "misconfigured plugin, only group is empty" { config_set '(.plugin_config.user="nobody") | (.plugin_config.group="")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: both plugin user and group must be set" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "misconfigured plugin, user does not exist" { config_set '(.plugin_config.user="userdoesnotexist") | (.plugin_config.group="groupdoesnotexist")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: user: unknown user userdoesnotexist" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "misconfigured plugin, group does not exist" { config_set '(.plugin_config.user=strenv(USER)) | (.plugin_config.group="groupdoesnotexist")' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: while getting process attributes: group: unknown group groupdoesnotexist" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "bad plugin name" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - cp "${PLUGIN_DIR}"/notification-http "${PLUGIN_DIR}"/badname + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + cp "$PLUGIN_DIR"/notification-http "$PLUGIN_DIR"/badname rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: plugin name ${PLUGIN_DIR}/badname is invalid. Name should be like {type-name}" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "duplicate notification config" { @@ -77,58 +77,58 @@ teardown() { # email_default has two configurations rune -0 yq -i '.name="email_default"' "$CONFIG_DIR/notifications/http.yaml" # enable a notification, otherwise plugins are ignored - config_set "${PROFILES_PATH}" '.notifications=["slack_default"]' + config_set "$PROFILES_PATH" '.notifications=["slack_default"]' # the slack plugin may fail or not, but we just need the logs config_set '.common.log_media="stdout"' rune wait-for \ --err "notification 'email_default' is defined multiple times" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "bad plugin permission (group writable)" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - chmod g+w "${PLUGIN_DIR}"/notification-http + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + chmod g+w "$PLUGIN_DIR"/notification-http rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is group writable, group writable plugins are invalid" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "bad plugin permission (world writable)" { - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' - chmod o+w "${PLUGIN_DIR}"/notification-http + config_set "$PROFILES_PATH" '.notifications=["http_default"]' + chmod o+w "$PLUGIN_DIR"/notification-http rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin: plugin at ${PLUGIN_DIR}/notification-http is world writable, world writable plugins are invalid" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "config.yaml: missing .plugin_config section" { config_set 'del(.plugin_config)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: plugins are enabled, but the plugin_config section is missing in the configuration" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "config.yaml: missing config_paths.notification_dir" { config_set 'del(.config_paths.notification_dir)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: plugins are enabled, but config_paths.notification_dir is not defined" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "config.yaml: missing config_paths.plugin_dir" { config_set 'del(.config_paths.plugin_dir)' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: plugins are enabled, but config_paths.plugin_dir is not defined" \ - "${CROWDSEC}" + "$CROWDSEC" } @test "unable to run plugin broker: while reading plugin config" { config_set '.config_paths.notification_dir="/this/path/does/not/exist"' - config_set "${PROFILES_PATH}" '.notifications=["http_default"]' + config_set "$PROFILES_PATH" '.notifications=["http_default"]' rune -0 wait-for \ --err "api server init: unable to run plugin broker: while loading plugin config: open /this/path/does/not/exist: no such file or directory" \ - "${CROWDSEC}" + "$CROWDSEC" } diff --git a/test/bats/73_plugin_formatting.bats b/test/bats/73_plugin_formatting.bats index 153193fb18f..9ed64837403 100644 --- a/test/bats/73_plugin_formatting.bats +++ b/test/bats/73_plugin_formatting.bats @@ -9,7 +9,7 @@ setup_file() { ./instance-data load - tempfile=$(TMPDIR="${BATS_FILE_TMPDIR}" mktemp) + tempfile=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp) export tempfile DUMMY_YAML="$(config_get '.config_paths.notification_dir')/dummy.yaml" @@ -17,7 +17,7 @@ setup_file() { # we test the template that is suggested in the email notification # the $alert is not a shell variable # shellcheck disable=SC2016 - config_set "${DUMMY_YAML}" ' + config_set "$DUMMY_YAML" ' .group_wait="5s" | .group_threshold=2 | .output_file=strenv(tempfile) | @@ -58,7 +58,7 @@ setup() { } @test "expected 1 notification" { - rune -0 cat "${tempfile}" + rune -0 cat "$tempfile" assert_output - <<-EOT

1.2.3.4 will get ban for next 30s for triggering manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX' on machine githubciXXXXXXXXXXXXXXXXXXXXXXXX.

CrowdSec CTI

1.2.3.5 will get ban for next 30s for triggering manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX' on machine githubciXXXXXXXXXXXXXXXXXXXXXXXX.

CrowdSec CTI

EOT diff --git a/test/bats/80_alerts.bats b/test/bats/80_alerts.bats index e0fdcb02271..6d84c1a1fce 100644 --- a/test/bats/80_alerts.bats +++ b/test/bats/80_alerts.bats @@ -73,9 +73,9 @@ teardown() { rune -0 cscli alerts list -o raw <(output) rune -0 grep 10.20.30.40 <(output) rune -0 cut -d, -f1 <(output) - ALERT_ID="${output}" + ALERT_ID="$output" - rune -0 cscli alerts inspect "${ALERT_ID}" -o human + rune -0 cscli alerts inspect "$ALERT_ID" -o human rune -0 plaintext < <(output) assert_line --regexp '^#+$' assert_line --regexp "^ - ID *: ${ALERT_ID}$" @@ -93,10 +93,10 @@ teardown() { assert_line --regexp "^.* ID .* scope:value .* action .* expiration .* created_at .*$" assert_line --regexp "^.* Ip:10.20.30.40 .* ban .*$" - rune -0 cscli alerts inspect "${ALERT_ID}" -o human --details + rune -0 cscli alerts inspect "$ALERT_ID" -o human --details # XXX can we have something here? - rune -0 cscli alerts inspect "${ALERT_ID}" -o raw + rune -0 cscli alerts inspect "$ALERT_ID" -o raw assert_line --regexp "^ *capacity: 0$" assert_line --regexp "^ *id: ${ALERT_ID}$" assert_line --regexp "^ *origin: cscli$" @@ -106,11 +106,11 @@ teardown() { assert_line --regexp "^ *type: ban$" assert_line --regexp "^ *value: 10.20.30.40$" - rune -0 cscli alerts inspect "${ALERT_ID}" -o json + rune -0 cscli alerts inspect "$ALERT_ID" -o json alert=${output} - rune jq -c '.decisions[] | [.origin,.scenario,.scope,.simulated,.type,.value]' <<<"${alert}" + rune jq -c '.decisions[] | [.origin,.scenario,.scope,.simulated,.type,.value]' <<<"$alert" assert_output --regexp "\[\"cscli\",\"manual 'ban' from 'githubciXXXXXXXXXXXXXXXXXXXXXXXX.*'\",\"Ip\",false,\"ban\",\"10.20.30.40\"\]" - rune jq -c '.source' <<<"${alert}" + rune jq -c '.source' <<<"$alert" assert_json '{ip:"10.20.30.40",scope:"Ip",value:"10.20.30.40"}' } @@ -188,7 +188,7 @@ teardown() { rune -0 cscli decisions add -i 10.20.30.40 -t ban rune -9 cscli decisions list --ip 10.20.30.40 -o json rune -9 jq -r '.[].decisions[].id' <(output) - DECISION_ID="${output}" + DECISION_ID="$output" ./instance-crowdsec stop rune -0 ./instance-db exec_sql "UPDATE decisions SET ... WHERE id=${DECISION_ID}" diff --git a/test/bats/81_alert_context.bats b/test/bats/81_alert_context.bats index df741f5f99c..69fb4158ffd 100644 --- a/test/bats/81_alert_context.bats +++ b/test/bats/81_alert_context.bats @@ -32,8 +32,8 @@ teardown() { #---------- @test "$FILE 1.1.1.172 has context" { - tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) - touch "${tmpfile}" + tmpfile=$(TMPDIR="$BATS_TEST_TMPDIR" mktemp) + touch "$tmpfile" ACQUIS_YAML=$(config_get '.crowdsec_service.acquisition_path') @@ -61,9 +61,9 @@ teardown() { ./instance-crowdsec start sleep 2 - fake_log >>"${tmpfile}" + fake_log >>"$tmpfile" sleep 2 - rm -f -- "${tmpfile}" + rm -f -- "$tmpfile" rune -0 cscli alerts list -o json rune -0 jq '.[0].id' <(output) diff --git a/test/bats/90_decisions.bats b/test/bats/90_decisions.bats index 8a2b9d3ae6f..b892dc84015 100644 --- a/test/bats/90_decisions.bats +++ b/test/bats/90_decisions.bats @@ -31,7 +31,6 @@ teardown() { @test "'decisions add' requires parameters" { rune -1 cscli decisions add - assert_line "Usage:" assert_stderr --partial "missing arguments, a value is required (--ip, --range or --scope and --value)" rune -1 cscli decisions add -o json @@ -109,12 +108,12 @@ teardown() { # invalid json rune -1 cscli decisions import -i - <<<'{"blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' # json with extra data rune -1 cscli decisions import -i - <<<'{"values":"1.2.3.4","blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' #---------- # CSV @@ -166,7 +165,7 @@ teardown() { # silently discarding (but logging) invalid decisions rune -0 cscli alerts delete --all - truncate -s 0 "${LOGFILE}" + truncate -s 0 "$LOGFILE" rune -0 cscli decisions import -i - --format values <<-EOT whatever @@ -180,9 +179,8 @@ teardown() { # disarding only some invalid decisions - rune -0 cscli alerts delete --all - truncate -s 0 "${LOGFILE}" + truncate -s 0 "$LOGFILE" rune -0 cscli decisions import -i - --format values <<-EOT 1.2.3.4 diff --git a/test/bats/97_ipv4_single.bats b/test/bats/97_ipv4_single.bats index 1ada1c4646b..b709930e2e5 100644 --- a/test/bats/97_ipv4_single.bats +++ b/test/bats/97_ipv4_single.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,11 +20,6 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { @@ -37,7 +30,7 @@ api() { } @test "API - first decisions list: must be empty" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' assert_output 'null' } @@ -53,7 +46,7 @@ api() { } @test "API - all decisions" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -c '[ . | length, .[0].value ]' <(output) assert_output '[1,"1.2.3.4"]' } @@ -67,7 +60,7 @@ api() { } @test "API - decision for 1.2.3.4" { - rune -0 api '/v1/decisions?ip=1.2.3.4' + rune -0 curl-with-key '/v1/decisions?ip=1.2.3.4' rune -0 jq -r '.[0].value' <(output) assert_output '1.2.3.4' } @@ -78,7 +71,7 @@ api() { } @test "API - decision for 1.2.3.5" { - rune -0 api '/v1/decisions?ip=1.2.3.5' + rune -0 curl-with-key '/v1/decisions?ip=1.2.3.5' assert_output 'null' } @@ -90,7 +83,7 @@ api() { } @test "API - decision for 1.2.3.0/24" { - rune -0 api '/v1/decisions?range=1.2.3.0/24' + rune -0 curl-with-key '/v1/decisions?range=1.2.3.0/24' assert_output 'null' } @@ -101,7 +94,7 @@ api() { } @test "API - decisions where IP in 1.2.3.0/24" { - rune -0 api '/v1/decisions?range=1.2.3.0/24&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1.2.3.0/24&contains=false' rune -0 jq -r '.[0].value' <(output) assert_output '1.2.3.4' } diff --git a/test/bats/97_ipv6_single.bats b/test/bats/97_ipv6_single.bats index ffbfc125b24..c7aea030f9c 100644 --- a/test/bats/97_ipv6_single.bats +++ b/test/bats/97_ipv6_single.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -19,12 +17,7 @@ teardown_file() { setup() { load "../lib/setup.sh" - if is_db_mysql; then sleep 0.3; fi -} - -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" + if is_db_mysql; then sleep 0.5; fi } #---------- @@ -48,7 +41,7 @@ api() { } @test "API - all decisions" { - rune -0 api "/v1/decisions" + rune -0 curl-with-key "/v1/decisions" rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @@ -60,7 +53,7 @@ api() { } @test "API - decisions for ip 1111:2222:3333:4444:5555:6666:7777:888" { - rune -0 api '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8888' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @@ -71,7 +64,7 @@ api() { } @test "API - decisions for ip 1211:2222:3333:4444:5555:6666:7777:888" { - rune -0 api '/v1/decisions?ip=1211:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=1211:2222:3333:4444:5555:6666:7777:8888' assert_output 'null' } @@ -81,7 +74,7 @@ api() { } @test "API - decisions for ip 1111:2222:3333:4444:5555:6666:7777:8887" { - rune -0 api '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8887' + rune -0 curl-with-key '/v1/decisions?ip=1111:2222:3333:4444:5555:6666:7777:8887' assert_output 'null' } @@ -91,7 +84,7 @@ api() { } @test "API - decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48' assert_output 'null' } @@ -102,7 +95,7 @@ api() { } @test "API - decisions for ip/range in 1111:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48&&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/48&&contains=false' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } @@ -113,7 +106,7 @@ api() { } @test "API - decisions for range 1111:2222:3333:4444:5555:6666:7777:8888/64" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64' assert_output 'null' } @@ -124,7 +117,7 @@ api() { } @test "API - decisions for ip/range in 1111:2222:3333:4444:5555:6666:7777:8888/64" { - rune -0 api '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64&&contains=false' + rune -0 curl-with-key '/v1/decisions?range=1111:2222:3333:4444:5555:6666:7777:8888/64&&contains=false' rune -0 jq -r '.[].value' <(output) assert_output '1111:2222:3333:4444:5555:6666:7777:8888' } diff --git a/test/bats/98_ipv4_range.bats b/test/bats/98_ipv4_range.bats index b0f6f482944..c85e40267f3 100644 --- a/test/bats/98_ipv4_range.bats +++ b/test/bats/98_ipv4_range.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,11 +20,6 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { @@ -48,7 +41,7 @@ api() { } @test "API - all decisions" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @@ -62,7 +55,7 @@ api() { } @test "API - decisions for ip 4.4.4." { - rune -0 api '/v1/decisions?ip=4.4.4.3' + rune -0 curl-with-key '/v1/decisions?ip=4.4.4.3' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @@ -73,7 +66,7 @@ api() { } @test "API - decisions for ip contained in 4.4.4." { - rune -0 api '/v1/decisions?ip=4.4.4.4&contains=false' + rune -0 curl-with-key '/v1/decisions?ip=4.4.4.4&contains=false' assert_output 'null' } @@ -83,7 +76,7 @@ api() { } @test "API - decisions for ip 5.4.4." { - rune -0 api '/v1/decisions?ip=5.4.4.3' + rune -0 curl-with-key '/v1/decisions?ip=5.4.4.3' assert_output 'null' } @@ -93,7 +86,7 @@ api() { } @test "API - decisions for range 4.4.0.0/1" { - rune -0 api '/v1/decisions?range=4.4.0.0/16' + rune -0 curl-with-key '/v1/decisions?range=4.4.0.0/16' assert_output 'null' } @@ -104,7 +97,7 @@ api() { } @test "API - decisions for ip/range in 4.4.0.0/1" { - rune -0 api '/v1/decisions?range=4.4.0.0/16&contains=false' + rune -0 curl-with-key '/v1/decisions?range=4.4.0.0/16&contains=false' rune -0 jq -r '.[0].value' <(output) assert_output '4.4.4.0/24' } @@ -118,7 +111,7 @@ api() { } @test "API - decisions for range 4.4.4.2/2" { - rune -0 api '/v1/decisions?range=4.4.4.2/28' + rune -0 curl-with-key '/v1/decisions?range=4.4.4.2/28' rune -0 jq -r '.[].value' <(output) assert_output '4.4.4.0/24' } @@ -129,6 +122,6 @@ api() { } @test "API - decisions for range 4.4.3.2/2" { - rune -0 api '/v1/decisions?range=4.4.3.2/28' + rune -0 curl-with-key '/v1/decisions?range=4.4.3.2/28' assert_output 'null' } diff --git a/test/bats/98_ipv6_range.bats b/test/bats/98_ipv6_range.bats index d3c347583da..531122a5533 100644 --- a/test/bats/98_ipv6_range.bats +++ b/test/bats/98_ipv6_range.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -22,11 +20,6 @@ setup() { if is_db_mysql; then sleep 0.3; fi } -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - #---------- @test "cli - first decisions list: must be empty" { @@ -48,7 +41,7 @@ api() { } @test "API - all decisions (2)" { - rune -0 api '/v1/decisions' + rune -0 curl-with-key '/v1/decisions' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @@ -62,7 +55,7 @@ api() { } @test "API - decisions for ip aaaa:2222:3333:4444:5555:6666:7777:8888" { - rune -0 api '/v1/decisions?ip=aaaa:2222:3333:4444:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=aaaa:2222:3333:4444:5555:6666:7777:8888' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @@ -73,7 +66,7 @@ api() { } @test "API - decisions for ip aaaa:2222:3333:4445:5555:6666:7777:8888" { - rune -0 api '/v1/decisions?ip=aaaa:2222:3333:4445:5555:6666:7777:8888' + rune -0 curl-with-key '/v1/decisions?ip=aaaa:2222:3333:4445:5555:6666:7777:8888' assert_output 'null' } @@ -83,7 +76,7 @@ api() { } @test "API - decisions for ip aaa1:2222:3333:4444:5555:6666:7777:8887" { - rune -0 api '/v1/decisions?ip=aaa1:2222:3333:4444:5555:6666:7777:8887' + rune -0 curl-with-key '/v1/decisions?ip=aaa1:2222:3333:4444:5555:6666:7777:8887' assert_output 'null' } @@ -96,7 +89,7 @@ api() { } @test "API - decisions for range aaaa:2222:3333:4444:5555::/80" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555::/80' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @@ -108,7 +101,7 @@ api() { } @test "API - decisions for range aaaa:2222:3333:4441:5555::/80" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4441:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4441:5555::/80' assert_output 'null' } @@ -118,7 +111,7 @@ api() { } @test "API - decisions for range aaa1:2222:3333:4444:5555::/80" { - rune -0 api '/v1/decisions?range=aaa1:2222:3333:4444:5555::/80' + rune -0 curl-with-key '/v1/decisions?range=aaa1:2222:3333:4444:5555::/80' assert_output 'null' } @@ -130,7 +123,7 @@ api() { } @test "API - decisions for range aaaa:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48' assert_output 'null' } @@ -141,7 +134,7 @@ api() { } @test "API - decisions for ip/range in aaaa:2222:3333:4444:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48&contains=false' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4444:5555:6666:7777:8888/48&contains=false' rune -0 jq -r '.[].value' <(output) assert_output 'aaaa:2222:3333:4444::/64' } @@ -152,7 +145,7 @@ api() { } @test "API - decisions for ip/range in aaaa:2222:3333:4445:5555:6666:7777:8888/48" { - rune -0 api '/v1/decisions?range=aaaa:2222:3333:4445:5555:6666:7777:8888/48' + rune -0 curl-with-key '/v1/decisions?range=aaaa:2222:3333:4445:5555:6666:7777:8888/48' assert_output 'null' } @@ -170,7 +163,7 @@ api() { } @test "API - decisions for ip in bbbb:db8:0000:0000:0000:6fff:ffff:ffff" { - rune -0 api '/v1/decisions?ip=bbbb:db8:0000:0000:0000:6fff:ffff:ffff' + rune -0 curl-with-key '/v1/decisions?ip=bbbb:db8:0000:0000:0000:6fff:ffff:ffff' rune -0 jq -r '.[].value' <(output) assert_output 'bbbb:db8::/81' } @@ -181,7 +174,7 @@ api() { } @test "API - decisions for ip in bbbb:db8:0000:0000:0000:8fff:ffff:ffff" { - rune -0 api '/v1/decisions?ip=bbbb:db8:0000:0000:0000:8fff:ffff:ffff' + rune -0 curl-with-key '/v1/decisions?ip=bbbb:db8:0000:0000:0000:8fff:ffff:ffff' assert_output 'null' } diff --git a/test/bats/99_lapi-stream-mode-scenario.bats b/test/bats/99_lapi-stream-mode-scenario.bats index 9b4d562f3c9..32c346061d1 100644 --- a/test/bats/99_lapi-stream-mode-scenario.bats +++ b/test/bats/99_lapi-stream-mode-scenario.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -24,16 +22,10 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key:${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - output_new_decisions() { jq -c '.new | map(select(.origin!="CAPI")) | .[] | del(.id) | (.. | .duration?) |= capture("(?[[:digit:]]+h[[:digit:]]+m)").d' <(output) | sort } - @test "adding decisions with different duration, scenario, origin" { # origin: test rune -0 cscli decisions add -i 127.0.0.1 -d 1h -R crowdsecurity/test @@ -62,7 +54,7 @@ output_new_decisions() { } @test "test startup" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -71,7 +63,7 @@ output_new_decisions() { } @test "test startup with scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"another_origin","scenario":"crowdsecurity/ssh_bf","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -80,7 +72,7 @@ output_new_decisions() { } @test "test startup with multiple scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"another_origin","scenario":"crowdsecurity/ssh_bf","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -89,12 +81,12 @@ output_new_decisions() { } @test "test startup with unknown scenarios containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=unknown" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=unknown" assert_output '{"deleted":null,"new":null}' } @test "test startup with scenarios containing and not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=test&scenarios_not_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=test&scenarios_not_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -103,7 +95,7 @@ output_new_decisions() { } @test "test startup with scenarios containing and not containing 2" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_containing=longest&scenarios_not_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_containing=longest&scenarios_not_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"4h59m","origin":"test","scenario":"crowdsecurity/longest","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -111,7 +103,7 @@ output_new_decisions() { } @test "test startup with scenarios not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -120,7 +112,7 @@ output_new_decisions() { } @test "test startup with multiple scenarios not containing" { - rune -0 api "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh_bf,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"4h59m","origin":"test","scenario":"crowdsecurity/longest","scope":"Ip","type":"ban","value":"127.0.0.1"} @@ -128,7 +120,7 @@ output_new_decisions() { } @test "test startup with origins parameter" { - rune -0 api "/v1/decisions/stream?startup=true&origins=another_origin" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=another_origin" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"1h59m","origin":"another_origin","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -137,7 +129,7 @@ output_new_decisions() { } @test "test startup with multiple origins parameter" { - rune -0 api "/v1/decisions/stream?startup=true&origins=another_origin,test" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=another_origin,test" rune -0 output_new_decisions assert_output - <<-EOT {"duration":"2h59m","origin":"test","scenario":"crowdsecurity/test","scope":"Ip","type":"ban","value":"127.0.0.2"} @@ -146,7 +138,7 @@ output_new_decisions() { } @test "test startup with unknown origins" { - rune -0 api "/v1/decisions/stream?startup=true&origins=unknown" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&origins=unknown" assert_output '{"deleted":null,"new":null}' } @@ -230,4 +222,3 @@ output_new_decisions() { # NewChecks: []DecisionCheck{}, # }, #} - diff --git a/test/bats/99_lapi-stream-mode-scopes.bats b/test/bats/99_lapi-stream-mode-scopes.bats index a1d01c489e6..67badebea0e 100644 --- a/test/bats/99_lapi-stream-mode-scopes.bats +++ b/test/bats/99_lapi-stream-mode-scopes.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -23,11 +21,6 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - @test "adding decisions for multiple scopes" { rune -0 cscli decisions add -i '1.2.3.6' assert_stderr --partial 'Decision successfully added' @@ -36,28 +29,28 @@ api() { } @test "stream start (implicit ip scope)" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' refute_output --partial 'toto' } @test "stream start (explicit ip scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=ip" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=ip" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' refute_output --partial 'toto' } @test "stream start (user scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=user" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=user" rune -0 jq -r '.new' <(output) refute_output --partial '1.2.3.6' assert_output --partial 'toto' } @test "stream start (user+ip scope)" { - rune -0 api "/v1/decisions/stream?startup=true&scopes=user,ip" + rune -0 curl-with-key "/v1/decisions/stream?startup=true&scopes=user,ip" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.6' assert_output --partial 'toto' diff --git a/test/bats/99_lapi-stream-mode.bats b/test/bats/99_lapi-stream-mode.bats index 08ddde42c5f..b3ee8a434ff 100644 --- a/test/bats/99_lapi-stream-mode.bats +++ b/test/bats/99_lapi-stream-mode.bats @@ -9,8 +9,6 @@ setup_file() { ./instance-crowdsec start API_KEY=$(cscli bouncers add testbouncer -o raw) export API_KEY - CROWDSEC_API_URL="http://localhost:8080" - export CROWDSEC_API_URL } teardown_file() { @@ -23,11 +21,6 @@ setup() { #---------- -api() { - URI="$1" - curl -s -H "X-Api-Key: ${API_KEY}" "${CROWDSEC_API_URL}${URI}" -} - @test "adding decisions for multiple ips" { rune -0 cscli decisions add -i '1111:2222:3333:4444:5555:6666:7777:8888' assert_stderr --partial 'Decision successfully added' @@ -38,7 +31,7 @@ api() { } @test "stream start" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" if is_db_mysql; then sleep 3; fi rune -0 jq -r '.new' <(output) assert_output --partial '1111:2222:3333:4444:5555:6666:7777:8888' @@ -49,7 +42,7 @@ api() { @test "stream cont (add)" { rune -0 cscli decisions add -i '1.2.3.5' if is_db_mysql; then sleep 3; fi - rune -0 api "/v1/decisions/stream" + rune -0 curl-with-key "/v1/decisions/stream" rune -0 jq -r '.new' <(output) assert_output --partial '1.2.3.5' } @@ -57,13 +50,13 @@ api() { @test "stream cont (del)" { rune -0 cscli decisions delete -i '1.2.3.4' if is_db_mysql; then sleep 3; fi - rune -0 api "/v1/decisions/stream" + rune -0 curl-with-key "/v1/decisions/stream" rune -0 jq -r '.deleted' <(output) assert_output --partial '1.2.3.4' } @test "stream restart" { - rune -0 api "/v1/decisions/stream?startup=true" + rune -0 curl-with-key "/v1/decisions/stream?startup=true" api_out=${output} rune -0 jq -r '.deleted' <(output) assert_output --partial '1.2.3.4' diff --git a/test/bats/testdata/cfssl/agent.json b/test/bats/testdata/cfssl/agent.json index 693e3aa512b..47b342e5a40 100644 --- a/test/bats/testdata/cfssl/agent.json +++ b/test/bats/testdata/cfssl/agent.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "agent-ou", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/agent_invalid.json b/test/bats/testdata/cfssl/agent_invalid.json index c61d4dee677..eb7db8d96fb 100644 --- a/test/bats/testdata/cfssl/agent_invalid.json +++ b/test/bats/testdata/cfssl/agent_invalid.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "this-is-not-the-ou-youre-looking-for", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/bouncer.json b/test/bats/testdata/cfssl/bouncer.json index 9a07f576610..bf642c48ad8 100644 --- a/test/bats/testdata/cfssl/bouncer.json +++ b/test/bats/testdata/cfssl/bouncer.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "bouncer-ou", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/bouncer_invalid.json b/test/bats/testdata/cfssl/bouncer_invalid.json index c61d4dee677..eb7db8d96fb 100644 --- a/test/bats/testdata/cfssl/bouncer_invalid.json +++ b/test/bats/testdata/cfssl/bouncer_invalid.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,5 +12,5 @@ "OU": "this-is-not-the-ou-youre-looking-for", "ST": "France" } - ] - } \ No newline at end of file + ] +} diff --git a/test/bats/testdata/cfssl/ca.json b/test/bats/testdata/cfssl/ca.json deleted file mode 100644 index ed907e0375b..00000000000 --- a/test/bats/testdata/cfssl/ca.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "CN": "CrowdSec Test CA", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ - { - "C": "FR", - "L": "Paris", - "O": "Crowdsec", - "OU": "Crowdsec", - "ST": "France" - } - ] -} \ No newline at end of file diff --git a/test/bats/testdata/cfssl/intermediate.json b/test/bats/testdata/cfssl/ca_intermediate.json similarity index 53% rename from test/bats/testdata/cfssl/intermediate.json rename to test/bats/testdata/cfssl/ca_intermediate.json index 3996ce6e189..34f1583da06 100644 --- a/test/bats/testdata/cfssl/intermediate.json +++ b/test/bats/testdata/cfssl/ca_intermediate.json @@ -1,10 +1,10 @@ { - "CN": "CrowdSec Test CA Intermediate", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "CrowdSec Test CA Intermediate", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,8 +12,8 @@ "OU": "Crowdsec Intermediate", "ST": "France" } - ], - "ca": { + ], + "ca": { "expiry": "42720h" } - } \ No newline at end of file +} diff --git a/test/bats/testdata/cfssl/ca_root.json b/test/bats/testdata/cfssl/ca_root.json new file mode 100644 index 00000000000..a0d64796637 --- /dev/null +++ b/test/bats/testdata/cfssl/ca_root.json @@ -0,0 +1,16 @@ +{ + "CN": "CrowdSec Test CA", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "Crowdsec", + "ST": "France" + } + ] +} diff --git a/test/bats/testdata/cfssl/profiles.json b/test/bats/testdata/cfssl/profiles.json index d0dfced4a47..47611beb64c 100644 --- a/test/bats/testdata/cfssl/profiles.json +++ b/test/bats/testdata/cfssl/profiles.json @@ -1,44 +1,37 @@ { - "signing": { - "default": { + "signing": { + "default": { + "expiry": "8760h" + }, + "profiles": { + "intermediate_ca": { + "usages": [ + "signing", + "key encipherment", + "cert sign", + "crl sign", + "server auth", + "client auth" + ], + "expiry": "8760h", + "ca_constraint": { + "is_ca": true, + "max_path_len": 0, + "max_path_len_zero": true + } + }, + "server": { + "usages": [ + "server auth" + ], "expiry": "8760h" }, - "profiles": { - "intermediate_ca": { - "usages": [ - "signing", - "digital signature", - "key encipherment", - "cert sign", - "crl sign", - "server auth", - "client auth" - ], - "expiry": "8760h", - "ca_constraint": { - "is_ca": true, - "max_path_len": 0, - "max_path_len_zero": true - } - }, - "server": { - "usages": [ - "signing", - "digital signing", - "key encipherment", - "server auth" - ], - "expiry": "8760h" - }, - "client": { - "usages": [ - "signing", - "digital signature", - "key encipherment", - "client auth" - ], - "expiry": "8760h" - } + "client": { + "usages": [ + "client auth" + ], + "expiry": "8760h" } } - } \ No newline at end of file + } +} diff --git a/test/bats/testdata/cfssl/server.json b/test/bats/testdata/cfssl/server.json index 37018259e95..cce97037ca7 100644 --- a/test/bats/testdata/cfssl/server.json +++ b/test/bats/testdata/cfssl/server.json @@ -1,10 +1,10 @@ { - "CN": "localhost", - "key": { - "algo": "rsa", - "size": 2048 - }, - "names": [ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ { "C": "FR", "L": "Paris", @@ -12,9 +12,9 @@ "OU": "Crowdsec Server", "ST": "France" } - ], - "hosts": [ - "127.0.0.1", - "localhost" - ] - } \ No newline at end of file + ], + "hosts": [ + "127.0.0.1", + "localhost" + ] +} diff --git a/test/bin/mock-http.py b/test/bin/mock-http.py index 3f26271b400..d11a4ebf717 100644 --- a/test/bin/mock-http.py +++ b/test/bin/mock-http.py @@ -6,6 +6,7 @@ from http.server import HTTPServer, BaseHTTPRequestHandler + class RequestHandler(BaseHTTPRequestHandler): def do_POST(self): request_path = self.path @@ -18,7 +19,7 @@ def do_POST(self): } print(json.dumps(log)) self.send_response(200) - self.send_header('Content-type','application/json') + self.send_header('Content-type', 'application/json') self.end_headers() self.wfile.write(json.dumps({}).encode()) self.wfile.flush() @@ -27,6 +28,7 @@ def do_POST(self): def log_message(self, format, *args): return + def main(argv): try: port = int(argv[1]) @@ -42,6 +44,6 @@ def main(argv): return 0 -if __name__ == "__main__" : +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) sys.exit(main(sys.argv)) diff --git a/test/bin/preload-hub-items b/test/bin/preload-hub-items index 14e9cff998c..79e20efbea2 100755 --- a/test/bin/preload-hub-items +++ b/test/bin/preload-hub-items @@ -9,34 +9,24 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # pre-download everything but don't install anything -echo -n "Purging existing hub..." +echo "Pre-downloading Hub content..." -types=$("$CSCLI" hub types -o raw) - -for itemtype in $types; do - "$CSCLI" "${itemtype}" delete --all --error --purge --force -done - -echo " done." +start=$(date +%s%N) -echo -n "Pre-downloading Hub content..." +types=$("$CSCLI" hub types -o raw) for itemtype in $types; do - ALL_ITEMS=$("$CSCLI" "$itemtype" list -a -o json | jq --arg itemtype "$itemtype" -r '.[$itemtype][].name') + ALL_ITEMS=$("$CSCLI" "$itemtype" list -a -o json | itemtype="$itemtype" yq '.[env(itemtype)][] | .name') if [[ -n "${ALL_ITEMS}" ]]; then #shellcheck disable=SC2086 "$CSCLI" "$itemtype" install \ $ALL_ITEMS \ - --download-only \ - --error + --download-only fi done -# XXX: download-only works only for collections, not for parsers, scenarios, postoverflows. -# so we have to delete the links manually, and leave the downloaded files in place - -for itemtype in $types; do - "$CSCLI" "$itemtype" delete --all --error -done +elapsed=$((($(date +%s%N) - start)/1000000)) +# bash only does integer arithmetic, we could use bc or have some fun with sed +elapsed=$(echo "$elapsed" | sed -e 's/...$/.&/;t' -e 's/.$/.0&/') -echo " done." +echo " done in $elapsed secs." diff --git a/test/bin/remove-all-hub-items b/test/bin/remove-all-hub-items new file mode 100755 index 00000000000..981602b775a --- /dev/null +++ b/test/bin/remove-all-hub-items @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -eu + +# shellcheck disable=SC1007 +THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +# shellcheck disable=SC1091 +. "${THIS_DIR}/../.environment.sh" + +# pre-download everything but don't install anything + +echo "Pre-downloading Hub content..." + +types=$("$CSCLI" hub types -o raw) + +for itemtype in $types; do + "$CSCLI" "$itemtype" remove --all --force +done + +echo " done." diff --git a/test/bin/wait-for b/test/bin/wait-for index 6c6fdd5ce2b..b226783d44b 100755 --- a/test/bin/wait-for +++ b/test/bin/wait-for @@ -39,7 +39,7 @@ async def monitor(cmd, args, want_out, want_err, timeout): status = None - async def read_stream(p, stream, outstream, pattern): + async def read_stream(stream, outstream, pattern): nonlocal status if stream is None: return @@ -84,8 +84,8 @@ async def monitor(cmd, args, want_out, want_err, timeout): await asyncio.wait_for( asyncio.wait([ asyncio.create_task(process.wait()), - asyncio.create_task(read_stream(process, process.stdout, sys.stdout, out_regex)), - asyncio.create_task(read_stream(process, process.stderr, sys.stderr, err_regex)) + asyncio.create_task(read_stream(process.stdout, sys.stdout, out_regex)), + asyncio.create_task(read_stream(process.stderr, sys.stderr, err_regex)) ]), timeout) if status is None: status = process.returncode diff --git a/test/bin/wait-for-port b/test/bin/wait-for-port index 15408b8e5a0..72f26bf409c 100755 --- a/test/bin/wait-for-port +++ b/test/bin/wait-for-port @@ -54,10 +54,6 @@ def main(argv): if not args.quiet: write_error(ex) sys.exit(1) - else: - sys.exit(0) - - sys.exit(1) if __name__ == "__main__": diff --git a/test/disable-capi b/test/disable-capi index f19bef5314c..b847accae48 100755 --- a/test/disable-capi +++ b/test/disable-capi @@ -5,4 +5,4 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck disable=SC1091 . "${THIS_DIR}/.environment.sh" -yq e 'del(.api.server.online_client)' -i "${CONFIG_YAML}" +yq e 'del(.api.server.online_client)' -i "$CONFIG_YAML" diff --git a/test/enable-capi b/test/enable-capi index ddbf8764c44..59980e6a059 100755 --- a/test/enable-capi +++ b/test/enable-capi @@ -5,7 +5,7 @@ THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck disable=SC1091 . "${THIS_DIR}/.environment.sh" -online_api_credentials="$(dirname "${CONFIG_YAML}")/online_api_credentials.yaml" +online_api_credentials="$(dirname "$CONFIG_YAML")/online_api_credentials.yaml" export online_api_credentials -yq e '.api.server.online_client.credentials_path=strenv(online_api_credentials)' -i "${CONFIG_YAML}" +yq e '.api.server.online_client.credentials_path=strenv(online_api_credentials)' -i "$CONFIG_YAML" diff --git a/test/instance-crowdsec b/test/instance-crowdsec index d87145c3881..f0cef729693 100755 --- a/test/instance-crowdsec +++ b/test/instance-crowdsec @@ -2,15 +2,15 @@ #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh backend_script="./lib/init/crowdsec-${INIT_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then +if [[ ! -x "$backend_script" ]]; then echo "unknown init system '${INIT_BACKEND}'" >&2 exit 1 fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-data b/test/instance-data index 02742b4ec85..e7fd05a9e54 100755 --- a/test/instance-data +++ b/test/instance-data @@ -1,16 +1,26 @@ #!/usr/bin/env bash +set -eu + +die() { + echo >&2 "$@" + exit 1 +} + #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh +if [[ -f "$LOCAL_INIT_DIR/.lock" ]] && [[ "$1" != "unlock" ]]; then + die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" +fi + backend_script="./lib/config/config-${CONFIG_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then - echo "unknown config backend '${CONFIG_BACKEND}'" >&2 - exit 1 +if [[ ! -x "$backend_script" ]]; then + die "unknown config backend '${CONFIG_BACKEND}'" fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-db b/test/instance-db index fbbc18dc433..de09465bc32 100755 --- a/test/instance-db +++ b/test/instance-db @@ -2,7 +2,7 @@ #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" || exit 1 +cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh @@ -10,9 +10,9 @@ cd "${THIS_DIR}" || exit 1 backend_script="./lib/db/instance-${DB_BACKEND}" -if [[ ! -x "${backend_script}" ]]; then +if [[ ! -x "$backend_script" ]]; then echo "unknown database '${DB_BACKEND}'" >&2 exit 1 fi -exec "${backend_script}" "$@" +exec "$backend_script" "$@" diff --git a/test/instance-mock-http b/test/instance-mock-http index cca19b79e3e..b5a56d3489d 100755 --- a/test/instance-mock-http +++ b/test/instance-mock-http @@ -13,7 +13,7 @@ about() { #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) -cd "${THIS_DIR}" +cd "$THIS_DIR" # shellcheck disable=SC1091 . ./.environment.sh @@ -31,7 +31,7 @@ DAEMON_PID=${PID_DIR}/mock-http.pid start_instance() { [[ $# -lt 1 ]] && about daemonize \ - -p "${DAEMON_PID}" \ + -p "$DAEMON_PID" \ -e "${LOG_DIR}/mock-http.err" \ -o "${LOG_DIR}/mock-http.out" \ /usr/bin/env python3 -u "${THIS_DIR}/bin/mock-http.py" "$1" @@ -40,10 +40,10 @@ start_instance() { } stop_instance() { - if [[ -f "${DAEMON_PID}" ]]; then + if [[ -f "$DAEMON_PID" ]]; then # terminate with extreme prejudice, all the application data will be thrown away anyway - kill -9 "$(cat "${DAEMON_PID}")" > /dev/null 2>&1 - rm -f -- "${DAEMON_PID}" + kill -9 "$(cat "$DAEMON_PID")" > /dev/null 2>&1 + rm -f -- "$DAEMON_PID" fi } diff --git a/test/lib/color-formatter b/test/lib/color-formatter new file mode 100755 index 00000000000..aee8d750698 --- /dev/null +++ b/test/lib/color-formatter @@ -0,0 +1,355 @@ +#!/usr/bin/env bash + +# +# Taken from pretty formatter, minus the cursor movements. +# Used in gihtub workflows CI where color is allowed. +# + +set -e + +# shellcheck source=lib/bats-core/formatter.bash +source "$BATS_ROOT/lib/bats-core/formatter.bash" + +BASE_PATH=. +BATS_ENABLE_TIMING= + +while [[ "$#" -ne 0 ]]; do + case "$1" in + -T) + BATS_ENABLE_TIMING="-T" + ;; + --base-path) + shift + normalize_base_path BASE_PATH "$1" + ;; + esac + shift +done + +update_count_column_width() { + count_column_width=$((${#count} * 2 + 2)) + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # additional space for ' in %s sec' + count_column_width=$((count_column_width + ${#SECONDS} + 8)) + fi + # also update dependent value + update_count_column_left +} + +update_screen_width() { + screen_width="$(tput cols)" + # also update dependent value + update_count_column_left +} + +update_count_column_left() { + count_column_left=$((screen_width - count_column_width)) +} + +# avoid unset variables +count=0 +screen_width=80 +update_count_column_width +#update_screen_width +test_result= + +#trap update_screen_width WINCH + +begin() { + test_result= # reset to avoid carrying over result state from previous test + line_backoff_count=0 + #go_to_column 0 + #update_count_column_width + #buffer_with_truncation $((count_column_left - 1)) ' %s' "$name" + #clear_to_end_of_line + #go_to_column $count_column_left + #if [[ -n "$BATS_ENABLE_TIMING" ]]; then + # buffer "%${#count}s/${count} in %s sec" "$index" "$SECONDS" + #else + # buffer "%${#count}s/${count}" "$index" + #fi + #go_to_column 1 + buffer "%${#count}s" "$index" +} + +finish_test() { + #move_up $line_backoff_count + #go_to_column 0 + buffer "$@" + if [[ -n "${TIMEOUT-}" ]]; then + set_color 2 + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer ' [%s (timeout: %s)]' "$TIMING" "$TIMEOUT" + else + buffer ' [timeout: %s]' "$TIMEOUT" + fi + else + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + set_color 2 + buffer ' [%s]' "$TIMING" + fi + fi + advance + move_down $((line_backoff_count - 1)) +} + +pass() { + local TIMING="${1:-}" + finish_test ' ✓ %s' "$name" + test_result=pass +} + +skip() { + local reason="$1" TIMING="${2:-}" + if [[ -n "$reason" ]]; then + reason=": $reason" + fi + finish_test ' - %s (skipped%s)' "$name" "$reason" + test_result=skip +} + +fail() { + local TIMING="${1:-}" + set_color 1 bold + finish_test ' ✗ %s' "$name" + test_result=fail +} + +timeout() { + local TIMING="${1:-}" + set_color 3 bold + TIMEOUT="${2:-}" finish_test ' ✗ %s' "$name" + test_result=timeout +} + +log() { + case ${test_result} in + pass) + clear_color + ;; + fail) + set_color 1 + ;; + timeout) + set_color 3 + ;; + esac + buffer ' %s\n' "$1" + clear_color +} + +summary() { + if [ "$failures" -eq 0 ]; then + set_color 2 bold + else + set_color 1 bold + fi + + buffer '\n%d test' "$count" + if [[ "$count" -ne 1 ]]; then + buffer 's' + fi + + buffer ', %d failure' "$failures" + if [[ "$failures" -ne 1 ]]; then + buffer 's' + fi + + if [[ "$skipped" -gt 0 ]]; then + buffer ', %d skipped' "$skipped" + fi + + if ((timed_out > 0)); then + buffer ', %d timed out' "$timed_out" + fi + + not_run=$((count - passed - failures - skipped - timed_out)) + if [[ "$not_run" -gt 0 ]]; then + buffer ', %d not run' "$not_run" + fi + + if [[ -n "$BATS_ENABLE_TIMING" ]]; then + buffer " in $SECONDS seconds" + fi + + buffer '\n' + clear_color +} + +buffer_with_truncation() { + local width="$1" + shift + local string + + # shellcheck disable=SC2059 + printf -v 'string' -- "$@" + + if [[ "${#string}" -gt "$width" ]]; then + buffer '%s...' "${string:0:$((width - 4))}" + else + buffer '%s' "$string" + fi +} + +move_up() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dA' "$1" + fi +} + +move_down() { + if [[ $1 -gt 0 ]]; then # avoid moving if we got 0 + buffer '\x1B[%dB' "$1" + fi +} + +go_to_column() { + local column="$1" + buffer '\x1B[%dG' $((column + 1)) +} + +clear_to_end_of_line() { + buffer '\x1B[K' +} + +advance() { + clear_to_end_of_line + buffer '\n' + clear_color +} + +set_color() { + local color="$1" + local weight=22 + + if [[ "${2:-}" == 'bold' ]]; then + weight=1 + fi + buffer '\x1B[%d;%dm' "$((30 + color))" "$weight" +} + +clear_color() { + buffer '\x1B[0m' +} + +_buffer= + +buffer() { + local content + # shellcheck disable=SC2059 + printf -v content -- "$@" + _buffer+="$content" +} + +prefix_buffer_with() { + local old_buffer="$_buffer" + _buffer='' + "$@" + _buffer="$_buffer$old_buffer" +} + +flush() { + printf '%s' "$_buffer" + _buffer= +} + +finish() { + flush + printf '\n' +} + +trap finish EXIT +trap '' INT + +bats_tap_stream_plan() { + count="$1" + index=0 + passed=0 + failures=0 + skipped=0 + timed_out=0 + name= + update_count_column_width +} + +bats_tap_stream_begin() { + index="$1" + name="$2" + begin + flush +} + +bats_tap_stream_ok() { + index="$1" + name="$2" + ((++passed)) + + pass "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_skipped() { + index="$1" + name="$2" + ((++skipped)) + skip "$3" "${BATS_FORMATTER_TEST_DURATION:-}" +} + +bats_tap_stream_not_ok() { + index="$1" + name="$2" + + if [[ ${BATS_FORMATTER_TEST_TIMEOUT-x} != x ]]; then + timeout "${BATS_FORMATTER_TEST_DURATION:-}" "${BATS_FORMATTER_TEST_TIMEOUT}s" + ((++timed_out)) + else + fail "${BATS_FORMATTER_TEST_DURATION:-}" + ((++failures)) + fi + +} + +bats_tap_stream_comment() { # + local scope=$2 + # count the lines we printed after the begin text, + if [[ $line_backoff_count -eq 0 && $scope == begin ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + log "$1" +} + +bats_tap_stream_suite() { + #test_file="$1" + line_backoff_count=0 + index= + # indicate filename for failures + local file_name="${1#"$BASE_PATH"}" + name="File $file_name" + set_color 4 bold + buffer "%s\n" "$file_name" + clear_color +} + +line_backoff_count=0 +bats_tap_stream_unknown() { # + local scope=$2 + # count the lines we printed after the begin text, (or after suite, in case of syntax errors) + if [[ $line_backoff_count -eq 0 && ($scope == begin || $scope == suite) ]]; then + # if this is the first line after begin, go down one line + buffer "\n" + ((++line_backoff_count)) # prefix-increment to avoid "error" due to returning 0 + fi + + ((++line_backoff_count)) + ((line_backoff_count += ${#1} / screen_width)) # account for linebreaks due to length + buffer "%s\n" "$1" + flush +} + +bats_parse_internal_extended_tap + +summary diff --git a/test/lib/config/config-global b/test/lib/config/config-global index 68346c18875..9b2b71c1dd1 100755 --- a/test/lib/config/config-global +++ b/test/lib/config/config-global @@ -58,6 +58,7 @@ config_prepare() { # remove trailing slash from CONFIG_DIR # since it's assumed to be missing during the tests yq e -i ' + .api.server.listen_socket="/run/crowdsec.sock" | .config_paths.config_dir |= sub("/$", "") ' "${CONFIG_DIR}/config.yaml" } @@ -69,7 +70,10 @@ make_init_data() { ./instance-db config-yaml ./instance-db setup - ./bin/preload-hub-items + # preload some content and data files + "$CSCLI" collections install crowdsecurity/linux --download-only + # sub-items did not respect --download-only + ./bin/remove-all-hub-items # when installed packages are always using sqlite, so no need to regenerate # local credz for sqlite diff --git a/test/lib/config/config-local b/test/lib/config/config-local index e3b7bc685d4..3e3c806b616 100755 --- a/test/lib/config/config-local +++ b/test/lib/config/config-local @@ -9,7 +9,7 @@ die() { } about() { - die "usage: ${script_name} [make | load | clean]" + die "usage: ${script_name} [make | load | lock | unlock | clean]" } #shellcheck disable=SC1007 @@ -57,7 +57,6 @@ config_generate() { cp ../config/profiles.yaml \ ../config/simulation.yaml \ - ../config/local_api_credentials.yaml \ ../config/online_api_credentials.yaml \ "${CONFIG_DIR}/" @@ -81,7 +80,6 @@ config_generate() { .common.daemonize=true | del(.common.pid_dir) | .common.log_level="info" | - .common.force_color_logs=true | .common.log_dir=strenv(LOG_DIR) | .config_paths.config_dir=strenv(CONFIG_DIR) | .config_paths.data_dir=strenv(DATA_DIR) | @@ -95,6 +93,7 @@ config_generate() { .db_config.db_path=strenv(DATA_DIR)+"/crowdsec.db" | .db_config.use_wal=true | .api.client.credentials_path=strenv(CONFIG_DIR)+"/local_api_credentials.yaml" | + .api.server.listen_socket=strenv(DATA_DIR)+"/crowdsec.sock" | .api.server.profiles_path=strenv(CONFIG_DIR)+"/profiles.yaml" | .api.server.console_path=strenv(CONFIG_DIR)+"/console.yaml" | del(.api.server.online_client) @@ -115,11 +114,15 @@ make_init_data() { ./instance-db config-yaml ./instance-db setup - "$CSCLI" --warning hub update + "$CSCLI" --warning hub update --with-content - ./bin/preload-hub-items + # preload some content and data files + "$CSCLI" collections install crowdsecurity/linux --download-only + # sub-items did not respect --download-only + ./bin/remove-all-hub-items - "$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --auto --force + # force TCP, the default would be unix socket + "$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --url http://127.0.0.1:8080 --auto --force mkdir -p "$LOCAL_INIT_DIR" @@ -134,7 +137,16 @@ make_init_data() { remove_init_data } +lock_init_data() { + touch "${LOCAL_INIT_DIR}/.lock" +} + +unlock_init_data() { + rm -f "${LOCAL_INIT_DIR}/.lock" +} + load_init_data() { + [[ -f "${LOCAL_INIT_DIR}/.lock" ]] && die "init data is locked" ./bin/assert-crowdsec-not-running || die "Cannot load fixture data." if [[ ! -f "${LOCAL_INIT_DIR}/init-config-data.tar" ]]; then @@ -164,6 +176,12 @@ case "$1" in load) load_init_data ;; + lock) + lock_init_data + ;; + unlock) + unlock_init_data + ;; clean) remove_init_data ;; diff --git a/test/lib/db/instance-mysql b/test/lib/db/instance-mysql index 6b40c84acba..df38f09761f 100755 --- a/test/lib/db/instance-mysql +++ b/test/lib/db/instance-mysql @@ -21,7 +21,7 @@ about() { check_requirements() { if ! command -v mysql >/dev/null; then - die "missing required program 'mysql' as a mysql client (package mariadb-client-core-10.6 on debian like system)" + die "missing required program 'mysql' as a mysql client (package mariadb-client on debian like system)" fi } diff --git a/test/lib/setup_file.sh b/test/lib/setup_file.sh index 1aca32fa6d0..39a084596e2 100755 --- a/test/lib/setup_file.sh +++ b/test/lib/setup_file.sh @@ -155,6 +155,11 @@ assert_log() { } export -f assert_log +cert_serial_number() { + cfssl certinfo -cert "$1" | jq -r '.serial_number' +} +export -f cert_serial_number + # Compare ignoring the key order, and allow "expected" without quoted identifiers. # Preserve the output variable in case the following commands require it. assert_json() { @@ -260,7 +265,7 @@ hub_strip_index() { local INDEX INDEX=$(config_get .config_paths.index_path) local hub_min - hub_min=$(jq <"$INDEX" 'del(..|.content?) | del(..|.long_description?) | del(..|.deprecated?) | del (..|.labels?)') + hub_min=$(jq <"$INDEX" 'del(..|.long_description?) | del(..|.deprecated?) | del (..|.labels?)') echo "$hub_min" >"$INDEX" } export -f hub_strip_index @@ -276,3 +281,62 @@ rune() { run --separate-stderr "$@" } export -f rune + +# call the lapi through unix socket +# the path (and query string) must be the first parameter, the others will be passed to curl +curl-socket() { + [[ -z "$1" ]] && { fail "${FUNCNAME[0]}: missing path"; } + local path=$1 + shift + local socket + socket=$(config_get '.api.server.listen_socket') + [[ -z "$socket" ]] && { fail "${FUNCNAME[0]}: missing .api.server.listen_socket"; } + # curl needs a fake hostname when using a unix socket + curl --unix-socket "$socket" "http://lapi$path" "$@" +} +export -f curl-socket + +# call the lapi through tcp +# the path (and query string) must be the first parameter, the others will be passed to curl +curl-tcp() { + [[ -z "$1" ]] && { fail "${FUNCNAME[0]}: missing path"; } + local path=$1 + shift + local cred + cred=$(config_get .api.client.credentials_path) + local base_url + base_url="$(yq '.url' < "$cred")" + curl "$base_url$path" "$@" +} +export -f curl-tcp + +# call the lapi through unix socket with an API_KEY (authenticates as a bouncer) +# after $1, pass throught extra arguments to curl +curl-with-key() { + [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; } + curl-tcp "$@" -sS --fail-with-body -H "X-Api-Key: $API_KEY" +} +export -f curl-with-key + +# call the lapi through unix socket with a TOKEN (authenticates as a machine) +# after $1, pass throught extra arguments to curl +curl-with-token() { + [[ -z "$TOKEN" ]] && { fail "${FUNCNAME[0]}: missing TOKEN"; } + # curl needs a fake hostname when using a unix socket + curl-tcp "$@" -sS --fail-with-body -H "Authorization: Bearer $TOKEN" +} +export -f curl-with-token + +# as a log processor, connect to lapi and get a token +lp-get-token() { + local cred + cred=$(config_get .api.client.credentials_path) + local resp + resp=$(yq -oj -I0 '{"machine_id":.login,"password":.password}' < "$cred" | curl-socket '/v1/watchers/login' -s -X POST --data-binary @-) + if [[ "$(yq -e '.code' <<<"$resp")" != 200 ]]; then + echo "login_lp: failed to login" >&3 + return 1 + fi + echo "$resp" | yq -r '.token' +} +export -f lp-get-token diff --git a/test/run-tests b/test/run-tests index 21b7a7320c5..957eb663b9c 100755 --- a/test/run-tests +++ b/test/run-tests @@ -10,35 +10,37 @@ die() { # shellcheck disable=SC1007 TEST_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck source=./.environment.sh -. "${TEST_DIR}/.environment.sh" +. "$TEST_DIR/.environment.sh" -"${TEST_DIR}/bin/check-requirements" +"$TEST_DIR/bin/check-requirements" echo "Running tests..." -echo "DB_BACKEND: ${DB_BACKEND}" -if [[ -z "${TEST_COVERAGE}" ]]; then +echo "DB_BACKEND: $DB_BACKEND" +if [[ -z "$TEST_COVERAGE" ]]; then echo "Coverage report: no" else echo "Coverage report: yes" fi -dump_backend="$(cat "${LOCAL_INIT_DIR}/.backend")" -if [[ "${DB_BACKEND}" != "${dump_backend}" ]]; then - die "Can't run with backend '${DB_BACKEND}' because the test data was build with '${dump_backend}'" +[[ -f "$LOCAL_INIT_DIR/.lock" ]] && die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" + +dump_backend="$(cat "$LOCAL_INIT_DIR/.backend")" +if [[ "$DB_BACKEND" != "$dump_backend" ]]; then + die "Can't run with backend '$DB_BACKEND' because the test data was build with '$dump_backend'" fi if [[ $# -ge 1 ]]; then echo "test files: $*" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ "$@" else - echo "test files: ${TEST_DIR}/bats ${TEST_DIR}/dyn-bats" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + echo "test files: $TEST_DIR/bats $TEST_DIR/dyn-bats" + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ - "${TEST_DIR}/bats" "${TEST_DIR}/dyn-bats" + "$TEST_DIR/bats" "$TEST_DIR/dyn-bats" fi diff --git a/wizard.sh b/wizard.sh index 598f0c765f0..6e215365f6c 100755 --- a/wizard.sh +++ b/wizard.sh @@ -18,7 +18,6 @@ NC='\033[0m' SILENT="false" DOCKER_MODE="false" -CROWDSEC_RUN_DIR="/var/run" CROWDSEC_LIB_DIR="/var/lib/crowdsec" CROWDSEC_USR_DIR="/usr/local/lib/crowdsec" CROWDSEC_DATA_DIR="${CROWDSEC_LIB_DIR}/data" @@ -82,12 +81,14 @@ SLACK_PLUGIN_BINARY="./cmd/notification-slack/notification-slack" SPLUNK_PLUGIN_BINARY="./cmd/notification-splunk/notification-splunk" EMAIL_PLUGIN_BINARY="./cmd/notification-email/notification-email" SENTINEL_PLUGIN_BINARY="./cmd/notification-sentinel/notification-sentinel" +FILE_PLUGIN_BINARY="./cmd/notification-file/notification-file" HTTP_PLUGIN_CONFIG="./cmd/notification-http/http.yaml" SLACK_PLUGIN_CONFIG="./cmd/notification-slack/slack.yaml" SPLUNK_PLUGIN_CONFIG="./cmd/notification-splunk/splunk.yaml" EMAIL_PLUGIN_CONFIG="./cmd/notification-email/email.yaml" SENTINEL_PLUGIN_CONFIG="./cmd/notification-sentinel/sentinel.yaml" +FILE_PLUGIN_CONFIG="./cmd/notification-file/file.yaml" BACKUP_DIR=$(mktemp -d) @@ -409,12 +410,14 @@ check_cs_version () { install_crowdsec() { mkdir -p "${CROWDSEC_DATA_DIR}" (cd config && find patterns -type f -exec install -Dm 644 "{}" "${CROWDSEC_CONFIG_PATH}/{}" \; && cd ../) || exit + mkdir -p "${CROWDSEC_CONFIG_PATH}/acquis.d" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/scenarios" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/postoverflows" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/collections" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/patterns" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-configs" || exit mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-rules" || exit + mkdir -p "${CROWDSEC_CONFIG_PATH}/contexts" || exit mkdir -p "${CROWDSEC_CONSOLE_DIR}" || exit # tmp @@ -523,6 +526,7 @@ install_plugins(){ cp ${HTTP_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR} cp ${EMAIL_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR} cp ${SENTINEL_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR} + cp ${FILE_PLUGIN_BINARY} ${CROWDSEC_PLUGIN_DIR} if [[ ${DOCKER_MODE} == "false" ]]; then cp -n ${SLACK_PLUGIN_CONFIG} /etc/crowdsec/notifications/ @@ -530,6 +534,7 @@ install_plugins(){ cp -n ${HTTP_PLUGIN_CONFIG} /etc/crowdsec/notifications/ cp -n ${EMAIL_PLUGIN_CONFIG} /etc/crowdsec/notifications/ cp -n ${SENTINEL_PLUGIN_CONFIG} /etc/crowdsec/notifications/ + cp -n ${FILE_PLUGIN_CONFIG} /etc/crowdsec/notifications/ fi }