diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 8702b7235..547037c5a 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,13 @@ body: attributes: value: | ### Thank you for taking the time to file a bug report! - + Please fill out this form as completely as possible. - type: input id: version attributes: - label: What version of `nebula` are you using? + label: What version of `nebula` are you using? (`nebula -version`) placeholder: 0.0.0 validations: required: true @@ -41,10 +41,17 @@ body: attributes: label: Logs from affected hosts description: | - Provide logs from all affected hosts during the time of the issue. + Please provide logs from ALL affected hosts during the time of the issue. If you do not provide logs we will be unable to assist you! + + [Learn how to find Nebula logs here.](https://nebula.defined.net/docs/guides/viewing-nebula-logs/) + Improve formatting by using ``` at the beginning and end of each log block. + value: | + ``` + + ``` validations: - required: false + required: true - type: textarea id: configs @@ -52,6 +59,11 @@ body: label: Config files from affected hosts description: | Provide config files for all affected hosts. + Improve formatting by using ``` at the beginning and end of each config file. + value: | + ``` + + ``` validations: - required: false + required: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..abf74a0fa --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,22 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" + groups: + golang-x-dependencies: + patterns: + - "golang.org/x/*" + zx2c4-dependencies: + patterns: + - "golang.zx2c4.com/*" + protobuf-dependencies: + patterns: + - "github.com/golang/protobuf" + - "google.golang.org/protobuf" diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index a00453bfe..e0d41aec9 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -14,31 +14,21 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions/cache@v2 + - uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-gofmt1.19-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-gofmt1.19- + go-version: '1.22' + check-latest: true - name: Install goimports run: | - go get golang.org/x/tools/cmd/goimports - go build golang.org/x/tools/cmd/goimports + go install golang.org/x/tools/cmd/goimports@latest - name: gofmt run: | - if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ] + if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] then - find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d + find . 
-iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d exit 1 fi diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 572b0ffd4..a199f1d4d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,25 +7,24 @@ name: Create release and upload binaries jobs: build-linux: - name: Build Linux All + name: Build Linux/BSD All runs-on: ubuntu-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + check-latest: true - name: Build run: | - make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd + make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd release-openbsd release-netbsd mkdir release mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: linux-latest path: release @@ -34,13 +33,12 @@ jobs: name: Build Windows runs-on: windows-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + check-latest: true - name: Build run: | @@ -57,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: windows-latest path: build @@ -68,17 +66,16 @@ jobs: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} runs-on: macos-11 steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + check-latest: true - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v1 + uses: Apple-Actions/import-codesign-certs@v2 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} @@ -107,7 +104,7 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: darwin-latest path: ./release/* @@ -117,12 +114,16 @@ jobs: needs: [build-linux, build-darwin, build-windows] runs-on: ubuntu-latest steps: + - uses: actions/checkout@v4 + - name: Download artifacts - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 + with: + path: artifacts - name: Zip Windows run: | - cd windows-latest + cd artifacts/windows-latest cp windows-amd64/* . zip -r nebula-windows-amd64.zip nebula.exe nebula-cert.exe dist cp windows-arm64/* . @@ -130,6 +131,7 @@ jobs: - name: Create sha256sum run: | + cd artifacts for dir in linux-latest darwin-latest windows-latest do ( @@ -159,195 +161,12 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ github.ref }} - release_name: Release ${{ github.ref }} - draft: false - prerelease: false - - ## - ## Upload assets (I wish we could just upload the whole folder at once... 
- ## - - - name: Upload SHASUM256.txt - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./SHASUM256.txt - asset_name: SHASUM256.txt - asset_content_type: text/plain - - - name: Upload darwin zip - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./darwin-latest/nebula-darwin.zip - asset_name: nebula-darwin.zip - asset_content_type: application/zip - - - name: Upload windows-amd64 - uses: actions/upload-release-asset@v1.0.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./windows-latest/nebula-windows-amd64.zip - asset_name: nebula-windows-amd64.zip - asset_content_type: application/zip - - - name: Upload windows-arm64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./windows-latest/nebula-windows-arm64.zip - asset_name: nebula-windows-arm64.zip - asset_content_type: application/zip - - - name: Upload linux-amd64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-amd64.tar.gz - asset_name: nebula-linux-amd64.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-386 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-386.tar.gz - asset_name: nebula-linux-386.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-ppc64le - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-ppc64le.tar.gz - asset_name: nebula-linux-ppc64le.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-5 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-5.tar.gz - asset_name: nebula-linux-arm-5.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-6 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-6.tar.gz - asset_name: nebula-linux-arm-6.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-7 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-7.tar.gz - asset_name: nebula-linux-arm-7.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm64.tar.gz - asset_name: nebula-linux-arm64.tar.gz - asset_content_type: 
application/gzip - - - name: Upload linux-mips - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips.tar.gz - asset_name: nebula-linux-mips.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mipsle - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mipsle.tar.gz - asset_name: nebula-linux-mipsle.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips64.tar.gz - asset_name: nebula-linux-mips64.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips64le - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz - asset_name: nebula-linux-mips64le.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips-softfloat - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz - asset_name: nebula-linux-mips-softfloat.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-riscv64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-riscv64.tar.gz - asset_name: nebula-linux-riscv64.tar.gz - asset_content_type: application/gzip - - - name: Upload freebsd-amd64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz - asset_name: nebula-freebsd-amd64.tar.gz - asset_content_type: application/gzip + run: | + cd artifacts + gh release create \ + --verify-tag \ + --title "Release ${{ github.ref_name }}" \ + "${{ github.ref_name }}" \ + SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 162d52665..54833bdc7 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -18,24 +18,15 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions/cache@v2 + - uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.19- + go-version: '1.22' + check-latest: true - name: build - run: make bin-docker + run: make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race - name: setup docker image working-directory: ./.github/workflows/smoke @@ -53,4 +44,12 @@ jobs: working-directory: ./.github/workflows/smoke run: ./smoke-relay.sh + - name: setup docker image for P256 + working-directory: 
./.github/workflows/smoke + run: NAME="smoke-p256" CURVE=P256 ./build.sh + + - name: run smoke-p256 + working-directory: ./.github/workflows/smoke + run: NAME="smoke-p256" ./smoke.sh + timeout-minutes: 10 diff --git a/.github/workflows/smoke/Dockerfile b/.github/workflows/smoke/Dockerfile index 18460b32a..f8a89ef59 100644 --- a/.github/workflows/smoke/Dockerfile +++ b/.github/workflows/smoke/Dockerfile @@ -1,4 +1,6 @@ -FROM debian:buster +FROM ubuntu:jammy + +RUN apt-get update && apt-get install -y iputils-ping ncat tcpdump ADD ./build /nebula diff --git a/.github/workflows/smoke/build-relay.sh b/.github/workflows/smoke/build-relay.sh index 1ec23c775..70b07f4e3 100755 --- a/.github/workflows/smoke/build-relay.sh +++ b/.github/workflows/smoke/build-relay.sh @@ -41,4 +41,4 @@ EOF ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24" ) -sudo docker build -t nebula:smoke-relay . +docker build -t nebula:smoke-relay . diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index 0c20b3f84..9cbb20058 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -29,11 +29,11 @@ mkdir ./build OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml - ../../../../nebula-cert ca -name "Smoke Test" + ../../../../nebula-cert ca -curve "${CURVE:-25519}" -name "Smoke Test" ../../../../nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24" ../../../../nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24" ../../../../nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24" ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24" ) -sudo docker build -t nebula:smoke . +docker build -t "nebula:${NAME:-smoke}" . 
diff --git a/.github/workflows/smoke/genconfig.sh b/.github/workflows/smoke/genconfig.sh index 005734cce..373ea5fca 100755 --- a/.github/workflows/smoke/genconfig.sh +++ b/.github/workflows/smoke/genconfig.sh @@ -50,6 +50,8 @@ tun: dev: ${TUN_DEV:-nebula1} firewall: + inbound_action: reject + outbound_action: reject outbound: ${OUTBOUND:-$FIREWALL_ALL} inbound: ${INBOUND:-$FIREWALL_ALL} diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 91954d627..9c113e185 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -14,24 +14,24 @@ cleanup() { set +e if [ "$(jobs -r)" ] then - sudo docker kill lighthouse1 host2 host3 host4 + docker kill lighthouse1 host2 host3 host4 fi } trap cleanup EXIT -sudo docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test -sudo docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test -sudo docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test -sudo docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test +docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test +docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test +docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test +docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test -sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x @@ -39,44 +39,44 @@ echo echo " *** Testing ping from lighthouse1" echo set -x -sudo docker exec lighthouse1 ping -c1 192.168.100.2 -sudo docker exec lighthouse1 ping -c1 192.168.100.3 -sudo docker exec lighthouse1 ping -c1 192.168.100.4 +docker exec lighthouse1 ping -c1 192.168.100.2 +docker exec lighthouse1 ping -c1 192.168.100.3 +docker exec lighthouse1 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host2" echo set -x -sudo docker exec host2 ping -c1 192.168.100.1 +docker exec host2 ping -c1 192.168.100.1 # Should fail because no relay configured in this direction -! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 -! 
sudo docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1 +! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 +! docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x -sudo docker exec host3 ping -c1 192.168.100.1 -sudo docker exec host3 ping -c1 192.168.100.2 -sudo docker exec host3 ping -c1 192.168.100.4 +docker exec host3 ping -c1 192.168.100.1 +docker exec host3 ping -c1 192.168.100.2 +docker exec host3 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host4" echo set -x -sudo docker exec host4 ping -c1 192.168.100.1 +docker exec host4 ping -c1 192.168.100.1 # Should fail because relays not allowed -! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 -sudo docker exec host4 ping -c1 192.168.100.3 +! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 +docker exec host4 ping -c1 192.168.100.3 -sudo docker exec host4 sh -c 'kill 1' -sudo docker exec host3 sh -c 'kill 1' -sudo docker exec host2 sh -c 'kill 1' -sudo docker exec lighthouse1 sh -c 'kill 1' -sleep 1 +docker exec host4 sh -c 'kill 1' +docker exec host3 sh -c 'kill 1' +docker exec host2 sh -c 'kill 1' +docker exec lighthouse1 sh -c 'kill 1' +sleep 5 if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 213add30e..6d04027aa 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -14,60 +14,105 @@ cleanup() { set +e if [ "$(jobs -r)" ] then - sudo docker kill lighthouse1 host2 host3 host4 + docker kill lighthouse1 host2 host3 host4 fi } trap cleanup EXIT -sudo docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test -sudo docker run --name host2 --rm nebula:smoke -config host2.yml -test -sudo docker run --name host3 --rm nebula:smoke -config host3.yml -test -sudo docker run --name host4 --rm nebula:smoke -config host4.yml -test +CONTAINER="nebula:${NAME:-smoke}" -sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test +docker run --name host2 --rm "$CONTAINER" -config host2.yml -test +docker run --name host3 --rm "$CONTAINER" -config host3.yml -test +docker run --name host4 --rm "$CONTAINER" -config host4.yml -test + +docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --device /dev/net/tun:/dev/net/tun 
--cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 +# grab tcpdump pcaps for debugging +docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & +docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & +docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & +docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & +docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & +docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & + +docker exec host2 ncat -nklv 0.0.0.0 2000 & +docker exec host3 ncat -nklv 0.0.0.0 2000 & +docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & +docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & + set +x echo echo " *** Testing ping from lighthouse1" echo set -x -sudo docker exec lighthouse1 ping -c1 192.168.100.2 -sudo docker exec lighthouse1 ping -c1 192.168.100.3 +docker exec lighthouse1 ping -c1 192.168.100.2 +docker exec lighthouse1 ping -c1 192.168.100.3 set +x echo echo " *** Testing ping from host2" echo set -x -sudo docker exec host2 ping -c1 192.168.100.1 +docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall -! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 +! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 + +set +x +echo +echo " *** Testing ncat from host2" +echo +set -x +# Should fail because not allowed by host3 inbound firewall +! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 +! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x -sudo docker exec host3 ping -c1 192.168.100.1 -sudo docker exec host3 ping -c1 192.168.100.2 +docker exec host3 ping -c1 192.168.100.1 +docker exec host3 ping -c1 192.168.100.2 + +set +x +echo +echo " *** Testing ncat from host3" +echo +set -x +docker exec host3 ncat -nzv -w5 192.168.100.2 2000 +docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 set +x echo echo " *** Testing ping from host4" echo set -x -sudo docker exec host4 ping -c1 192.168.100.1 +docker exec host4 ping -c1 192.168.100.1 +# Should fail because not allowed by host4 outbound firewall +! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 +! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1 + +set +x +echo +echo " *** Testing ncat from host4" +echo +set -x # Should fail because not allowed by host4 outbound firewall -! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 -! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1 +! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1 +! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1 +! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1 +! 
docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo @@ -75,16 +120,16 @@ echo " *** Testing conntrack" echo set -x # host2 can ping host3 now that host3 pinged it first -sudo docker exec host2 ping -c1 192.168.100.3 +docker exec host2 ping -c1 192.168.100.3 # host4 can ping host2 once conntrack established -sudo docker exec host2 ping -c1 192.168.100.4 -sudo docker exec host4 ping -c1 192.168.100.2 +docker exec host2 ping -c1 192.168.100.4 +docker exec host4 ping -c1 192.168.100.2 -sudo docker exec host4 sh -c 'kill 1' -sudo docker exec host3 sh -c 'kill 1' -sudo docker exec host2 sh -c 'kill 1' -sudo docker exec lighthouse1 sh -c 'kill 1' -sleep 1 +docker exec host4 sh -c 'kill 1' +docker exec host3 sh -c 'kill 1' +docker exec host2 sh -c 'kill 1' +docker exec lighthouse1 sh -c 'kill 1' +sleep 5 if [ "$(jobs -r)" ] then diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 69ed606d4..d71262aef 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,37 +18,55 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions/cache@v2 + - uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.19- + go-version: '1.22' + check-latest: true - name: Build run: make all + - name: Vet + run: make vet + - name: Test run: make test - name: End 2 end run: make e2evv + - name: Build test mobile + run: make build-test-mobile + - uses: actions/upload-artifact@v3 with: name: e2e packet flow path: e2e/mermaid/ if-no-files-found: warn + test-linux-boringcrypto: + name: Build and test on linux with boringcrypto + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + check-latest: true + + - name: Build + run: make bin-boringcrypto + + - name: Test + run: make test-boringcrypto + + - name: End 2 end + run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1 + test: name: Build and test on ${{ matrix.os }} runs-on: ${{ matrix.os }} @@ -57,21 +75,12 @@ jobs: os: [windows-latest, macos-11] steps: - - name: Set up Go 1.19 - uses: actions/setup-go@v2 - with: - go-version: 1.19 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions/cache@v2 + - uses: actions/setup-go@v5 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.19- + go-version: '1.22' + check-latest: true - name: Build nebula run: go build ./cmd/nebula @@ -79,8 +88,11 @@ jobs: - name: Build nebula-cert run: go build ./cmd/nebula-cert + - name: Vet + run: make vet + - name: Test - run: go test -v ./... + run: make test - name: End 2 end run: make e2evv diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f0ebbb20..71c3ed47b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,166 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.8.2] - 2024-01-08 + +### Fixed + +- Fix multiple routines when listen.port is zero. This was a regression + introduced in v1.6.0. (#1057) + +### Changed + +- Small dependency update for Noise. 
(#1038) + +## [1.8.1] - 2023-12-19 + +### Security + +- Update `golang.org/x/crypto`, which includes a fix for CVE-2023-48795. (#1048) + +### Fixed + +- Fix a deadlock introduced in v1.8.0 that could occur during handshakes. (#1044) + +- Fix mobile builds. (#1035) + +## [1.8.0] - 2023-12-06 + +### Deprecated + +- The next minor release of Nebula, 1.9.0, will require at least Windows 10 or + Windows Server 2016. This is because support for earlier versions was removed + in Go 1.21. See https://go.dev/doc/go1.21#windows + +### Added + +- Linux: Notify systemd of service readiness. This should resolve timing issues + with services that depend on Nebula being active. For an example of how to + enable this, see: `examples/service_scripts/nebula.service`. (#929) + +- Windows: Use Registered IO (RIO) when possible. Testing on a Windows 11 + machine shows ~50x improvement in throughput. (#905) + +- NetBSD, OpenBSD: Added rudimentary support. (#916, #812) + +- FreeBSD: Add support for naming tun devices. (#903) + +### Changed + +- `pki.disconnect_invalid` will now default to true. This means that once a + certificate expires, the tunnel will be disconnected. If you use SIGHUP to + reload certificates without restarting Nebula, you should ensure all of your + clients are on 1.7.0 or newer before you enable this feature. (#859) + +- Limit how often a busy tunnel can requery the lighthouse. The new config + option `timers.requery_wait_duration` defaults to `60s`. (#940) + +- The internal structures for hostmaps were refactored to reduce memory usage + and the potential for subtle bugs. (#843, #938, #953, #954, #955) + +- Lots of dependency updates. + +### Fixed + +- Windows: Retry wintun device creation if it fails the first time. (#985) + +- Fix issues with firewall reject packets that could cause panics. (#957) + +- Fix relay migration during re-handshakes. (#964) + +- Various other refactors and fixes. (#935, #952, #972, #961, #996, #1002, + #987, #1004, #1030, #1032, ...) + +## [1.7.2] - 2023-06-01 + +### Fixed + +- Fix a freeze during config reload if the `static_host_map` config was changed. (#886) + +## [1.7.1] - 2023-05-18 + +### Fixed + +- Fix IPv4 addresses returned by `static_host_map` DNS lookup queries being + treated as IPv6 addresses. (#877) + +## [1.7.0] - 2023-05-17 + +### Added + +- `nebula-cert ca` now supports encrypting the CA's private key with a + passphrase. Pass `-encrypt` in order to be prompted for a passphrase. + Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF + parameters default to RFC recommendations, but can be overridden via CLI + flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`. (#386) + +- Support for curve P256 and BoringCrypto has been added. See README section + "Curve P256 and BoringCrypto" for more details. (#865, #861, #769, #856, #803) + +- New firewall rule `local_cidr`. This could be used to filter destinations + when using `unsafe_routes`. (#507) + +- Add `unsafe_route` option `install`. This controls whether the route is + installed in the systems routing table. (#831) + +- Add `tun.use_system_route_table` option. Set to true to manage unsafe routes + directly on the system route table with gateway routes instead of in Nebula + configuration files. This is only supported on Linux. (#839) + +- The metric `certificate.ttl_seconds` is now exposed via stats. (#782) + +- Add `punchy.respond_delay` option. This allows you to change the delay + before attempting punchy.respond. Default is 5 seconds. 
(#721) + +- Added SSH commands to allow the capture of a mutex profile. (#737) + +- You can now set `lighthouse.calculated_remotes` to make it possible to do + handshakes without a lighthouse in certain configurations. (#759) + +- The firewall can be configured to send REJECT replies instead of the default + DROP behavior. (#738) + +- For macOS, an example launchd configuration file is now provided. (#762) + +### Changed + +- Lighthouses and other `static_host_map` entries that use DNS names will now + be automatically refreshed to detect when the IP address changes. (#796) + +- Lighthouses send ACK replies back to clients so that they do not fall into + connection testing as often by clients. (#851, #408) + +- Allow the `listen.host` option to contain a hostname. (#825) + +- When Nebula switches to a new certificate (such as via SIGHUP), we now + rehandshake with all existing tunnels. This allows firewall groups to be + updated and `pki.disconnect_invalid` to know about the new certificate + expiration time. (#838, #857, #842, #840, #835, #828, #820, #807) + +### Fixed + +- Always disconnect blocklisted hosts, even if `pki.disconnect_invalid` is + not set. (#858) + +- Dependencies updated and go1.20 required. (#780, #824, #855, #854) + +- Fix possible race condition with relays. (#827) + +- FreeBSD: Fix connection to the localhost's own Nebula IP. (#808) + +- Normalize and document some common log field values. (#837, #811) + +- Fix crash if you set unlucky values for the firewall timeout configuration + options. (#802) + +- Make DNS queries case insensitive. (#793) + +- Update example systemd configurations to want `nss-lookup`. (#791) + +- Errors with SSH commands now go to the SSH tunnel instead of stderr. (#757) + +- Fix a hang when shutting down Android. (#772) + ## [1.6.1] - 2022-09-26 ### Fixed @@ -398,7 +558,13 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.6.1...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.2...HEAD +[1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 +[1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 +[1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0 +[1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2 +[1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1 +[1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0 [1.6.1]: https://github.com/slackhq/nebula/releases/tag/v1.6.1 [1.6.0]: https://github.com/slackhq/nebula/releases/tag/v1.6.0 [1.5.2]: https://github.com/slackhq/nebula/releases/tag/v1.5.2 diff --git a/LOGGING.md b/LOGGING.md new file mode 100644 index 000000000..e2508c83c --- /dev/null +++ b/LOGGING.md @@ -0,0 +1,37 @@ +### Logging conventions + +A log message (the string/format passed to `Info`, `Error`, `Debug` etc, as well as their `Sprintf` counterparts) should +be a descriptive message about the event and may contain specific identifying characteristics. Regardless of the +level of detail in the message identifying characteristics should always be included via `WithField`, `WithFields` or +`WithError` + +If an error is being logged use `l.WithError(err)` so that there is better discoverability about the event as well +as the specific error condition. + +#### Common fields + +- `cert` - a `cert.NebulaCertificate` object, do not `.String()` this manually, `logrus` will marshal objects properly + for the formatter it is using. 
+- `fingerprint` - a single `NebulaCertificate` hex encoded fingerprint +- `fingerprints` - an array of `NebulaCertificate` hex encoded fingerprints +- `fwPacket` - a FirewallPacket object +- `handshake` - an object containing: + - `stage` - the current stage counter + - `style` - noise handshake style `ix_psk0`, `xx`, etc +- `header` - a nebula header object +- `udpAddr` - a `net.UDPAddr` object +- `udpIp` - a udp ip address +- `vpnIp` - vpn ip of the host (remote or local) +- `relay` - the vpnIp of the relay host that is or should be handling the relay packet +- `relayFrom` - The vpnIp of the initial sender of the relayed packet +- `relayTo` - The vpnIp of the final destination of a relayed packet + +#### Example: + +``` +l.WithError(err). + WithField("vpnIp", IntIp(hostinfo.hostId)). + WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix"}). + Info("Invalid certificate from host") +``` \ No newline at end of file diff --git a/Makefile b/Makefile index b31c0fcc1..4d0a9e580 100644 --- a/Makefile +++ b/Makefile @@ -1,20 +1,14 @@ -GOMINVERSION = 1.19 NEBULA_CMD_PATH = "./cmd/nebula" -GO111MODULE = on -export GO111MODULE CGO_ENABLED = 0 export CGO_ENABLED # Set up OS specific bits ifeq ($(OS),Windows_NT) - #TODO: we should be able to ditch awk as well - GOVERSION := $(shell go version | awk "{print substr($$3, 3)}") - GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1) NEBULA_CMD_SUFFIX = .exe NULL_FILE = nul + # RIO on windows does pointer stuff that makes go vet angry + VET_FLAGS = -unsafeptr=false else - GOVERSION := $(shell go version | awk '{print substr($$3, 3)}') - GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)") NEBULA_CMD_SUFFIX = NULL_FILE = /dev/null endif @@ -44,10 +38,21 @@ ALL_LINUX = linux-amd64 \ linux-mips-softfloat \ linux-riscv64 +ALL_FREEBSD = freebsd-amd64 \ + freebsd-arm64 + +ALL_OPENBSD = openbsd-amd64 \ + openbsd-arm64 + +ALL_NETBSD = netbsd-amd64 \ + netbsd-arm64 + ALL = $(ALL_LINUX) \ + $(ALL_FREEBSD) \ + $(ALL_OPENBSD) \ + $(ALL_NETBSD) \ darwin-amd64 \ darwin-arm64 \ - freebsd-amd64 \ windows-amd64 \ windows-arm64 @@ -75,7 +80,13 @@ release: $(ALL:%=build/nebula-%.tar.gz) release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz) -release-freebsd: build/nebula-freebsd-amd64.tar.gz +release-freebsd: $(ALL_FREEBSD:%=build/nebula-%.tar.gz) + +release-openbsd: $(ALL_OPENBSD:%=build/nebula-%.tar.gz) + +release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz) + +release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz BUILD_ARGS = -trimpath @@ -91,6 +102,12 @@ bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert mv $? . +bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert + mv $? . + +bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert + mv $? . 
+ bin: go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH} go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert @@ -105,6 +122,10 @@ build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*)) # Build an extra small binary for mips-softfloat build/linux-mips-softfloat/%: LDFLAGS += -s -w +# boringcrypto +build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 +build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 + build/%/nebula: .FORCE GOOS=$(firstword $(subst -, , $*)) \ GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \ @@ -128,15 +149,24 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe vet: - go vet -v ./... + go vet $(VET_FLAGS) -v ./... test: go test -v ./... +test-boringcrypto: + GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./... + test-cov-html: go test -coverprofile=coverage.out go tool cover -html=coverage.out +build-test-mobile: + GOARCH=amd64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=arm64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=amd64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=arm64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + bench: go test -bench=. @@ -170,14 +200,17 @@ bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert smoke-docker: bin-docker cd .github/workflows/smoke/ && ./build.sh cd .github/workflows/smoke/ && ./smoke.sh + cd .github/workflows/smoke/ && NAME="smoke-p256" CURVE="P256" ./build.sh + cd .github/workflows/smoke/ && NAME="smoke-p256" ./smoke.sh smoke-relay-docker: bin-docker cd .github/workflows/smoke/ && ./build-relay.sh cd .github/workflows/smoke/ && ./smoke-relay.sh smoke-docker-race: BUILD_ARGS = -race +smoke-docker-race: CGO_ENABLED = 1 smoke-docker-race: smoke-docker .FORCE: -.PHONY: e2e e2ev e2evv e2evvv e2evvvv test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race +.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html .DEFAULT_GOAL := bin diff --git a/README.md b/README.md index ba4e99742..51e913d5d 100644 --- a/README.md +++ b/README.md @@ -27,15 +27,26 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for #### Distribution Packages -- [Arch Linux](https://archlinux.org/packages/community/x86_64/nebula/) +- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/) ``` $ sudo pacman -S nebula ``` + - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula) ``` $ sudo dnf install nebula ``` +- [Debian Linux](https://packages.debian.org/source/stable/nebula) + ``` + $ sudo apt install nebula + ``` + +- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula) + ``` + $ sudo apk add nebula + ``` + - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb) ``` $ brew install nebula @@ -108,7 +119,7 @@ For each host, copy the nebula binary to the host, along with `config.yml` from ## Building Nebula from source -Download go and clone this repo. Change to the nebula directory. +Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory. 
To build nebula for all platforms: `make all` @@ -118,6 +129,17 @@ To build nebula for a specific platform (ex, Windows): See the [Makefile](Makefile) for more details on build targets +## Curve P256 and BoringCrypto + +The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes. + +In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets: + + make bin-boringcrypto + make release-boringcrypto + +This is not the recommended default deployment, but may be useful based on your compliance requirements. + ## Credits Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..bfff621a7 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +Security Policy +=============== + +Reporting a Vulnerability +------------------------- + +If you believe you have found a security vulnerability with Nebula, please let +us know right away. We will investigate all reports and do our best to quickly +fix valid issues. + +You can submit your report on [HackerOne](https://hackerone.com/slack) and our +security team will respond as soon as possible. diff --git a/allow_list.go b/allow_list.go index 0e44a126e..9186b2fc7 100644 --- a/allow_list.go +++ b/allow_list.go @@ -12,7 +12,7 @@ import ( type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6 + cidrTree *cidr.Tree6[bool] } type RemoteAllowList struct { @@ -20,7 +20,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6 + insideAllowLists *cidr.Tree6[*AllowList] } type LocalAllowList struct { @@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) { +func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6() + remoteAllowRanges := cidr.NewTree6[*AllowList]() rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool { return true } - result := al.cidrTree.MostSpecificContains(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContains(ip) + return result } func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { @@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { return 
true } - result := al.cidrTree.MostSpecificContainsIpV4(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV4(ip) + return result } func (al *AllowList) AllowIpV6(hi, lo uint64) bool { @@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool { return true } - result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + return result } func (al *LocalAllowList) Allow(ip net.IP) bool { @@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { if al.insideAllowLists != nil { - inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) - if inside != nil { - return inside.(*AllowList) + ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + if ok { + return inside } } return nil diff --git a/allow_list_test.go b/allow_list_test.go index 991b8a319..334cb6062 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) { func TestAllowList_Allow(t *testing.T) { assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) diff --git a/boring.go b/boring.go new file mode 100644 index 000000000..9cd9d37f2 --- /dev/null +++ b/boring.go @@ -0,0 +1,8 @@ +//go:build boringcrypto +// +build boringcrypto + +package nebula + +import "crypto/boring" + +var boringEnabled = boring.Enabled diff --git a/calculated_remote.go b/calculated_remote.go new file mode 100644 index 000000000..38f5bea25 --- /dev/null +++ b/calculated_remote.go @@ -0,0 +1,143 @@ +package nebula + +import ( + "fmt" + "math" + "net" + "strconv" + + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" +) + +// This allows us to "guess" what the remote might be for a host while we wait +// for the lighthouse response. See "lighthouse.calculated_remotes" in the +// example config file. 
+type calculatedRemote struct { + ipNet net.IPNet + maskIP iputil.VpnIp + mask iputil.VpnIp + port uint32 +} + +func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { + // Ensure this is an IPv4 mask that we expect + ones, bits := ipNet.Mask.Size() + if ones == 0 || bits != 32 { + return nil, fmt.Errorf("invalid mask: %v", ipNet) + } + if port < 0 || port > math.MaxUint16 { + return nil, fmt.Errorf("invalid port: %d", port) + } + + return &calculatedRemote{ + ipNet: *ipNet, + maskIP: iputil.Ip2VpnIp(ipNet.IP), + mask: iputil.Ip2VpnIp(ipNet.Mask), + port: uint32(port), + }, nil +} + +func (c *calculatedRemote) String() string { + return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) +} + +func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes + // of the overlay IP + masked := (c.maskIP & c.mask) | (ip & ^c.mask) + + return &Ip4AndPort{Ip: uint32(masked), Port: c.port} +} + +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { + value := c.Get(k) + if value == nil { + return nil, nil + } + + calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + + rawMap, ok := value.(map[any]any) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) + } + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + _, ipNet, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + entry, err := newCalculatedRemotesListFromConfig(rawValue) + if err != nil { + return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) + } + + calculatedRemotes.AddCIDR(ipNet, entry) + } + + return calculatedRemotes, nil +} + +func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { + rawList, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) + } + + var l []*calculatedRemote + for _, e := range rawList { + c, err := newCalculatedRemotesEntryFromConfig(e) + if err != nil { + return nil, fmt.Errorf("calculated_remotes entry: %w", err) + } + l = append(l, c) + } + + return l, nil +} + +func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { + rawMap, ok := raw.(map[any]any) + if !ok { + return nil, fmt.Errorf("invalid type: %T", raw) + } + + rawValue := rawMap["mask"] + if rawValue == nil { + return nil, fmt.Errorf("missing mask: %v", rawMap) + } + rawMask, ok := rawValue.(string) + if !ok { + return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) + } + _, ipNet, err := net.ParseCIDR(rawMask) + if err != nil { + return nil, fmt.Errorf("invalid mask: %s", rawMask) + } + + var port int + rawValue = rawMap["port"] + if rawValue == nil { + return nil, fmt.Errorf("missing port: %v", rawMap) + } + switch v := rawValue.(type) { + case int: + port = v + case string: + port, err = strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("invalid port: %s: %w", v, err) + } + default: + return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) + } + + return newCalculatedRemote(ipNet, port) +} diff --git a/calculated_remote_test.go b/calculated_remote_test.go new file mode 100644 index 000000000..2ddebca74 --- /dev/null +++ b/calculated_remote_test.go @@ -0,0 +1,27 @@ +package nebula + +import ( + "net" + 
"testing" + + "github.com/slackhq/nebula/iputil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCalculatedRemoteApply(t *testing.T) { + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + require.NoError(t, err) + + c, err := newCalculatedRemote(ipNet, 4242) + require.NoError(t, err) + + input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + + expected := &Ip4AndPort{ + Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), + Port: 4242, + } + + assert.Equal(t, expected, c.Apply(input)) +} diff --git a/cert.go b/cert.go deleted file mode 100644 index be7bb6a4e..000000000 --- a/cert.go +++ /dev/null @@ -1,163 +0,0 @@ -package nebula - -import ( - "errors" - "fmt" - "io/ioutil" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" -) - -type CertState struct { - certificate *cert.NebulaCertificate - rawCertificate []byte - rawCertificateNoKey []byte - publicKey []byte - privateKey []byte -} - -func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) - } - - publicKey := certificate.Details.PublicKey - cs := &CertState{ - rawCertificate: rawCertificate, - certificate: certificate, // PublicKey has been set to nil above - privateKey: privateKey, - publicKey: publicKey, - } - - cs.certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) - } - cs.rawCertificateNoKey = rawCertNoKey - // put public key back - cs.certificate.Details.PublicKey = cs.publicKey - return cs, nil -} - -func NewCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte - var err error - - privPathOrPEM := c.GetString("pki.key", "") - - if privPathOrPEM == "" { - return nil, errors.New("no pki.key path or PEM data provided") - } - - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - } else { - pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, err := cert.UnmarshalX25519PrivateKey(pemPrivateKey) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - - var rawCert []byte - - pubPathOrPEM := c.GetString("pki.cert", "") - - if pubPathOrPEM == "" { - return nil, errors.New("no pki.cert path or PEM data provided") - } - - if strings.Contains(pubPathOrPEM, "-----BEGIN") { - rawCert = []byte(pubPathOrPEM) - pubPathOrPEM = "" - } else { - rawCert, err = ioutil.ReadFile(pubPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) - } - } - - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) - } - - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") - } - - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") - } - - if err = nebulaCert.VerifyPrivateKey(rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public 
key in nebula cert") - } - - return NewCertState(nebulaCert, rawKey) -} - -func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { - var rawCA []byte - var err error - - caPathOrPEM := c.GetString("pki.ca", "") - if caPathOrPEM == "" { - return nil, errors.New("no pki.ca path or PEM data provided") - } - - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) - - } else { - rawCA, err = ioutil.ReadFile(caPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) - } - } - - CAs, err := cert.NewCAPoolFromBytes(rawCA) - if errors.Is(err, cert.ErrExpired) { - var expired int - for _, cert := range CAs.CAs { - if cert.Expired(time.Now()) { - expired++ - l.WithField("cert", cert).Warn("expired certificate present in CA pool") - } - } - - if expired >= len(CAs.CAs) { - return nil, errors.New("no valid CA certificates present") - } - - } else if err != nil { - return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) - } - - for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - CAs.BlocklistFingerprint(fp) - } - - // Support deprecated config for at least one minor release to allow for migrations - //TODO: remove in 2022 or later - for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist") - CAs.BlocklistFingerprint(fp) - } - - return CAs, nil -} diff --git a/cert/ca.go b/cert/ca.go index d7005441f..0ffbd8792 100644 --- a/cert/ca.go +++ b/cert/ca.go @@ -91,9 +91,15 @@ func (ncp *NebulaCAPool) ResetCertBlocklist() { ncp.certBlocklist = make(map[string]struct{}) } -// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted +// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated +// automatically if you manually change any fields in the NebulaCertificate. 
func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool { - h, err := c.Sha256Sum() + return ncp.isBlocklistedWithCache(c, false) +} + +// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted +func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool { + h, err := c.sha256SumWithCache(useCache) if err != nil { return true } diff --git a/cert/cert.go b/cert/cert.go index f3df89c1a..4f1b776c0 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -2,35 +2,55 @@ package cert import ( "bytes" - "crypto" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/pem" + "errors" "fmt" + "math" + "math/big" "net" + "sync/atomic" "time" "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" "google.golang.org/protobuf/proto" ) const publicKeyLen = 32 const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + CertBanner = "NEBULA CERTIFICATE" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + + P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" + P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" + EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" + ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" ) type NebulaCertificate struct { Details NebulaCertificateDetails Signature []byte + + // the cached hex string of the calculated sha256sum + // for VerifyWithCache + sha256sum atomic.Pointer[string] + + // the cached public key bytes if they were verified as the signer + // for VerifyWithCache + signatureVerified atomic.Pointer[[]byte] } type NebulaCertificateDetails struct { @@ -46,10 +66,25 @@ type NebulaCertificateDetails struct { // Map of groups for faster lookup InvertedGroups map[string]struct{} + + Curve Curve +} + +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters } type m map[string]interface{} +// Returned if we try to unmarshal an encrypted private key without a passphrase +var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + // UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { if len(b) == 0 { @@ -84,6 +119,7 @@ func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { PublicKey: make([]byte, len(rc.Details.PublicKey)), IsCA: rc.Details.IsCA, InvertedGroups: make(map[string]struct{}), + Curve: rc.Details.Curve, }, Signature: make([]byte, len(rc.Signature)), } @@ -134,6 +170,28 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er return nc, r, err } +func MarshalPrivateKey(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: 
P256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +func MarshalSigningPrivateKey(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + // MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key func MarshalX25519PrivateKey(b []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) @@ -144,6 +202,90 @@ func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) } +func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PrivateKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PrivateKeyBanner: + expectedLen = 32 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var curve Curve + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted + case EncryptedECDSAP256PrivateKeyBanner: + return nil, nil, Curve_P256, ErrPrivateKeyEncrypted + case Ed25519PrivateKeyBanner: + curve = Curve_CURVE25519 + if len(k.Bytes) != ed25519.PrivateKeySize { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) + } + case ECDSAP256PrivateKeyBanner: + curve = Curve_P256 + if len(k.Bytes) != 32 { + return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner") + } + return k.Bytes, r, curve, nil +} + +// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key +func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + if err != nil { + return nil, err + } + + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil + default: + return nil, fmt.Errorf("invalid curve: %v", curve) + } +} + // UnmarshalX25519PrivateKey 
will try to pem decode an X25519 private key, returning any other bytes b // or an error on failure func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { @@ -168,9 +310,13 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { if k == nil { return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") } - if k.Type != Ed25519PrivateKeyBanner { + + if k.Type == EncryptedEd25519PrivateKeyBanner { + return nil, r, ErrPrivateKeyEncrypted + } else if k.Type != Ed25519PrivateKeyBanner { return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") } + if len(k.Bytes) != ed25519.PrivateKeySize { return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } @@ -178,6 +324,126 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { return k.Bytes, r, nil } +// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: rune(params.Version), + Memory: uint32(params.Memory), + Parallelism: uint8(params.Parallelism), + Iterations: uint32(params.Iterations), + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { + var curve Curve + + k, r := pem.Decode(b) + if k == nil { + return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + curve = Curve_CURVE25519 + case EncryptedECDSAP256PrivateKeyBanner: + curve = Curve_P256 + 
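// Only the two encrypted banners are accepted by this function; any other
// banner falls through to the error in the default case below. Unencrypted
// keys continue to go through UnmarshalSigningPrivateKey, which in turn
// returns ErrPrivateKeyEncrypted if it is handed one of the encrypted banners.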
default: + return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return curve, nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return curve, nil, r, err + } + default: + return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + switch curve { + case Curve_CURVE25519: + if len(bytes) != ed25519.PrivateKeySize { + return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) + } + case Curve_P256: + if len(bytes) != 32 { + return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + } + + return curve, bytes, r, nil +} + +func MarshalPublicKey(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) + default: + return nil + } +} + // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key func MarshalX25519PublicKey(b []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) @@ -188,6 +454,30 @@ func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte { return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key}) } +func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PublicKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PublicKeyBanner: + // Uncompressed + expectedLen = 65 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + // UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b // or an error on failure func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) { @@ -223,27 +513,86 @@ func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) { } // Sign signs a nebula cert with the provided private key -func (nc *NebulaCertificate) Sign(key ed25519.PrivateKey) error { +func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { + if curve != nc.Details.Curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + b, err := proto.Marshal(nc.getRawDetails()) if err != nil { return err } - sig, err := key.Sign(rand.Reader, b, crypto.Hash(0)) - if err != nil { - return err + var sig []byte + + switch curve { + case Curve_CURVE25519: + signer := ed25519.PrivateKey(key) + sig = ed25519.Sign(signer, b) + case Curve_P256: + signer := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: 
https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) + + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(b) + sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) + if err != nil { + return err + } + default: + return fmt.Errorf("invalid curve: %s", nc.Details.Curve) } + nc.Signature = sig return nil } // CheckSignature verifies the signature against the provided public key -func (nc *NebulaCertificate) CheckSignature(key ed25519.PublicKey) bool { +func (nc *NebulaCertificate) CheckSignature(key []byte) bool { b, err := proto.Marshal(nc.getRawDetails()) if err != nil { return false } - return ed25519.Verify(key, b, nc.Signature) + switch nc.Details.Curve { + case Curve_CURVE25519: + return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature) + default: + return false + } +} + +// NOTE: This uses an internal cache that will not be invalidated automatically +// if you manually change any fields in the NebulaCertificate. +func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool { + if !useCache { + return nc.CheckSignature(key) + } + + if v := nc.signatureVerified.Load(); v != nil { + return bytes.Equal(*v, key) + } + + verified := nc.CheckSignature(key) + if verified { + keyCopy := make([]byte, len(key)) + copy(keyCopy, key) + nc.signatureVerified.Store(&keyCopy) + } + + return verified } // Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false @@ -253,8 +602,27 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool { // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { - if ncp.IsBlocklisted(nc) { - return false, fmt.Errorf("certificate has been blocked") + return nc.verify(t, ncp, false) +} + +// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) +// +// NOTE: This uses an internal cache that will not be invalidated automatically +// if you manually change any fields in the NebulaCertificate. +func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) { + return nc.verify(t, ncp, true) +} + +// ResetCache resets the cache used by VerifyWithCache. 
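// Call it after manually changing any Details fields (or the Signature) so the
// next VerifyWithCache call re-computes the SHA-256 fingerprint and re-checks
// the signature instead of reusing the cached results. A rough usage sketch,
// assuming caPool and peerCert were loaded elsewhere:
//
//	ok, err := peerCert.VerifyWithCache(time.Now(), caPool)
//	// ... peerCert.Details is mutated ...
//	peerCert.ResetCache()
//	ok, err = peerCert.VerifyWithCache(time.Now(), caPool)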
+func (nc *NebulaCertificate) ResetCache() { + nc.sha256sum.Store(nil) + nc.signatureVerified.Store(nil) +} + +// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) +func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) { + if ncp.isBlocklistedWithCache(nc, useCache) { + return false, ErrBlockListed } signer, err := ncp.GetCAForCert(nc) @@ -263,15 +631,15 @@ func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error } if signer.Expired(t) { - return false, fmt.Errorf("root certificate is expired") + return false, ErrRootExpired } if nc.Expired(t) { - return false, fmt.Errorf("certificate is expired") + return false, ErrExpired } - if !nc.CheckSignature(signer.Details.PublicKey) { - return false, fmt.Errorf("certificate signature did not match") + if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) { + return false, ErrSignatureMismatch } if err := nc.CheckRootConstrains(signer); err != nil { @@ -324,22 +692,52 @@ func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) erro } // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match -func (nc *NebulaCertificate) VerifyPrivateKey(key []byte) error { +func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != nc.Details.Curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } if nc.Details.IsCA { - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } - if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { - return fmt.Errorf("public key in cert and private key supplied don't match") + if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key as P256") + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, nc.Details.PublicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + default: + return fmt.Errorf("invalid curve: %s", curve) } return nil } - pub, err := curve25519.X25519(key, curve25519.Basepoint) - if err != nil { - return err + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return err + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return err + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) } if !bytes.Equal(pub, nc.Details.PublicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") @@ -393,6 +791,7 @@ func (nc *NebulaCertificate) String() string { s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) + s += 
fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve) s += "\t}\n" fp, err := nc.Sha256Sum() if err == nil { @@ -413,6 +812,7 @@ func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { NotAfter: nc.Details.NotAfter.Unix(), PublicKey: make([]byte, len(nc.Details.PublicKey)), IsCA: nc.Details.IsCA, + Curve: nc.Details.Curve, } for _, ipNet := range nc.Details.Ips { @@ -461,6 +861,25 @@ func (nc *NebulaCertificate) Sha256Sum() (string, error) { return hex.EncodeToString(sum[:]), nil } +// NOTE: This uses an internal cache that will not be invalidated automatically +// if you manually change any fields in the NebulaCertificate. +func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) { + if !useCache { + return nc.Sha256Sum() + } + + if s := nc.sha256sum.Load(); s != nil { + return *s, nil + } + s, err := nc.Sha256Sum() + if err != nil { + return s, err + } + + nc.sha256sum.Store(&s) + return s, nil +} + func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { toString := func(ips []*net.IPNet) []string { s := []string{} @@ -482,6 +901,7 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), "isCa": nc.Details.IsCA, "issuer": nc.Details.Issuer, + "curve": nc.Details.Curve.String(), }, "fingerprint": fp, "signature": fmt.Sprintf("%x", nc.Signature), diff --git a/cert/cert.pb.go b/cert/cert.pb.go index 094aefb40..3570e0750 100644 --- a/cert/cert.pb.go +++ b/cert/cert.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 -// protoc v3.20.0 +// protoc-gen-go v1.30.0 +// protoc v3.21.5 // source: cert.proto package cert @@ -20,6 +20,52 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type Curve int32 + +const ( + Curve_CURVE25519 Curve = 0 + Curve_P256 Curve = 1 +) + +// Enum value maps for Curve. +var ( + Curve_name = map[int32]string{ + 0: "CURVE25519", + 1: "P256", + } + Curve_value = map[string]int32{ + "CURVE25519": 0, + "P256": 1, + } +) + +func (x Curve) Enum() *Curve { + p := new(Curve) + *p = x + return p +} + +func (x Curve) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Curve) Descriptor() protoreflect.EnumDescriptor { + return file_cert_proto_enumTypes[0].Descriptor() +} + +func (Curve) Type() protoreflect.EnumType { + return &file_cert_proto_enumTypes[0] +} + +func (x Curve) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Curve.Descriptor instead. 
+func (Curve) EnumDescriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{0} +} + type RawNebulaCertificate struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -91,6 +137,7 @@ type RawNebulaCertificateDetails struct { IsCA bool `protobuf:"varint,8,opt,name=IsCA,proto3" json:"IsCA,omitempty"` // sha-256 of the issuer certificate, if this field is blank the cert is self-signed Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,proto3" json:"Issuer,omitempty"` + Curve Curve `protobuf:"varint,100,opt,name=curve,proto3,enum=cert.Curve" json:"curve,omitempty"` } func (x *RawNebulaCertificateDetails) Reset() { @@ -188,6 +235,202 @@ func (x *RawNebulaCertificateDetails) GetIssuer() []byte { return nil } +func (x *RawNebulaCertificateDetails) GetCurve() Curve { + if x != nil { + return x.Curve + } + return Curve_CURVE25519 +} + +type RawNebulaEncryptedData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionMetadata *RawNebulaEncryptionMetadata `protobuf:"bytes,1,opt,name=EncryptionMetadata,proto3" json:"EncryptionMetadata,omitempty"` + Ciphertext []byte `protobuf:"bytes,2,opt,name=Ciphertext,proto3" json:"Ciphertext,omitempty"` +} + +func (x *RawNebulaEncryptedData) Reset() { + *x = RawNebulaEncryptedData{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptedData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptedData) ProtoMessage() {} + +func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. 
+func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{2} +} + +func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { + if x != nil { + return x.EncryptionMetadata + } + return nil +} + +func (x *RawNebulaEncryptedData) GetCiphertext() []byte { + if x != nil { + return x.Ciphertext + } + return nil +} + +type RawNebulaEncryptionMetadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionAlgorithm string `protobuf:"bytes,1,opt,name=EncryptionAlgorithm,proto3" json:"EncryptionAlgorithm,omitempty"` + Argon2Parameters *RawNebulaArgon2Parameters `protobuf:"bytes,2,opt,name=Argon2Parameters,proto3" json:"Argon2Parameters,omitempty"` +} + +func (x *RawNebulaEncryptionMetadata) Reset() { + *x = RawNebulaEncryptionMetadata{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptionMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptionMetadata) ProtoMessage() {} + +func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. +func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{3} +} + +func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { + if x != nil { + return x.EncryptionAlgorithm + } + return "" +} + +func (x *RawNebulaEncryptionMetadata) GetArgon2Parameters() *RawNebulaArgon2Parameters { + if x != nil { + return x.Argon2Parameters + } + return nil +} + +type RawNebulaArgon2Parameters struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` // rune in Go + Memory uint32 `protobuf:"varint,2,opt,name=memory,proto3" json:"memory,omitempty"` + Parallelism uint32 `protobuf:"varint,4,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // uint8 in Go + Iterations uint32 `protobuf:"varint,3,opt,name=iterations,proto3" json:"iterations,omitempty"` + Salt []byte `protobuf:"bytes,5,opt,name=salt,proto3" json:"salt,omitempty"` +} + +func (x *RawNebulaArgon2Parameters) Reset() { + *x = RawNebulaArgon2Parameters{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaArgon2Parameters) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaArgon2Parameters) ProtoMessage() {} + +func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. 
+func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{4} +} + +func (x *RawNebulaArgon2Parameters) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetMemory() uint32 { + if x != nil { + return x.Memory + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetParallelism() uint32 { + if x != nil { + return x.Parallelism + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetIterations() uint32 { + if x != nil { + return x.Iterations + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetSalt() []byte { + if x != nil { + return x.Salt + } + return nil +} + var File_cert_proto protoreflect.FileDescriptor var file_cert_proto_rawDesc = []byte{ @@ -199,7 +442,7 @@ var file_cert_proto_rawDesc = []byte{ 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xf9, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, + 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73, @@ -215,9 +458,43 @@ var file_cert_proto_rawDesc = []byte{ 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, - 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63, + 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, + 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, + 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, + 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 
0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, + 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, + 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, + 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, + 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, + 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, + 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, + 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, + 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, + 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, + 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, + 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65, + 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00, + 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, + 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -232,18 +509,26 @@ func file_cert_proto_rawDescGZIP() []byte { return file_cert_proto_rawDescData } -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_cert_proto_goTypes = []interface{}{ - (*RawNebulaCertificate)(nil), // 0: cert.RawNebulaCertificate - (*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails + (Curve)(0), // 0: cert.Curve + (*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate + (*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails + (*RawNebulaEncryptedData)(nil), // 3: cert.RawNebulaEncryptedData + (*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata + (*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters } var file_cert_proto_depIdxs = []int32{ - 1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 2, // 0: 
cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails + 0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve + 4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata + 5, // 3: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_cert_proto_init() } @@ -276,19 +561,56 @@ func file_cert_proto_init() { return nil } } + file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptedData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptionMetadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaArgon2Parameters); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_cert_proto_rawDesc, - NumEnums: 0, - NumMessages: 2, + NumEnums: 1, + NumMessages: 5, NumExtensions: 0, NumServices: 0, }, GoTypes: file_cert_proto_goTypes, DependencyIndexes: file_cert_proto_depIdxs, + EnumInfos: file_cert_proto_enumTypes, MessageInfos: file_cert_proto_msgTypes, }.Build() File_cert_proto = out.File diff --git a/cert/cert.proto b/cert/cert.proto index e135dd1fd..36d043bcd 100644 --- a/cert/cert.proto +++ b/cert/cert.proto @@ -5,6 +5,11 @@ option go_package = "github.com/slackhq/nebula/cert"; //import "google/protobuf/timestamp.proto"; +enum Curve { + CURVE25519 = 0; + P256 = 1; +} + message RawNebulaCertificate { RawNebulaCertificateDetails Details = 1; bytes Signature = 2; @@ -26,4 +31,24 @@ message RawNebulaCertificateDetails { // sha-256 of the issuer certificate, if this field is blank the cert is self-signed bytes Issuer = 9; -} \ No newline at end of file + + Curve curve = 100; +} + +message RawNebulaEncryptedData { + RawNebulaEncryptionMetadata EncryptionMetadata = 1; + bytes Ciphertext = 2; +} + +message RawNebulaEncryptionMetadata { + string EncryptionAlgorithm = 1; + RawNebulaArgon2Parameters Argon2Parameters = 2; +} + +message RawNebulaArgon2Parameters { + int32 version = 1; // rune in Go + uint32 memory = 2; + uint32 parallelism = 4; // uint8 in Go + uint32 iterations = 3; + bytes salt = 5; +} diff --git a/cert/cert_test.go b/cert/cert_test.go index 5a8274152..30e99eca1 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -1,6 +1,9 @@ package cert import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "fmt" "io" @@ -101,7 +104,49 @@ func TestNebulaCertificate_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) assert.Nil(t, err) assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(priv)) + assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) + assert.True(t, nc.CheckSignature(pub)) + + _, 
err = nc.Marshal() + assert.Nil(t, err) + //t.Log("Cert size:", len(b)) +} + +func TestNebulaCertificate_SignP256(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Curve: Curve_P256, + Issuer: "1234567890abcedfghij1234567890ab", + }, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + assert.Nil(t, err) + assert.False(t, nc.CheckSignature(pub)) + assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) assert.True(t, nc.CheckSignature(pub)) _, err = nc.Marshal() @@ -153,7 +198,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", string(b), ) } @@ -177,7 +222,7 @@ func TestNebulaCertificate_Verify(t *testing.T) { v, err := c.Verify(time.Now(), caPool) assert.False(t, v) - assert.EqualError(t, err, "certificate has been blocked") + assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() v, err = c.Verify(time.Now(), caPool) @@ -217,6 +262,65 @@ func TestNebulaCertificate_Verify(t *testing.T) { assert.Nil(t, err) } +func TestNebulaCertificate_VerifyP256(t *testing.T) { + ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) 
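// newTestCert picks its keypair from ca.Details.Curve (see the P256 branch
// added to newTestCert further down), so the leaf certificate below is signed
// and verified with ECDSA P256 end to end.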
+ + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + + h, err := ca.Sha256Sum() + assert.Nil(t, err) + + caPool := NewCAPool() + caPool.CAs[h] = ca + + f, err := c.Sha256Sum() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + v, err := c.Verify(time.Now(), caPool) + assert.False(t, v) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + v, err = c.Verify(time.Now(), caPool) + assert.True(t, v) + assert.Nil(t, err) + + v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) + assert.False(t, v) + assert.EqualError(t, err, "root certificate is expired") + + c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) + assert.False(t, v) + assert.EqualError(t, err, "certificate is expired") + + // Test group assertion + ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) + assert.Nil(t, err) + + caPem, err := ca.MarshalToPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + caPool.AddCACertificate(caPem) + + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) + assert.Nil(t, err) + v, err = c.Verify(time.Now(), caPool) + assert.False(t, v) + assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") + + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) + assert.Nil(t, err) + v, err = c.Verify(time.Now(), caPool) + assert.True(t, v) + assert.Nil(t, err) +} + func TestNebulaCertificate_Verify_IPs(t *testing.T) { _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") @@ -378,20 +482,40 @@ func TestNebulaCertificate_Verify_Subnets(t *testing.T) { func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) assert.Nil(t, err) - err = ca.VerifyPrivateKey(caKey) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) assert.Nil(t, err) - err = ca.VerifyPrivateKey(caKey2) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.NotNil(t, err) c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - err = c.VerifyPrivateKey(priv) + err = c.VerifyPrivateKey(Curve_CURVE25519, priv) assert.Nil(t, err) _, priv2 := x25519Keypair() - err = c.VerifyPrivateKey(priv2) + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.NotNil(t, err) +} + +func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + err 
= c.VerifyPrivateKey(Curve_P256, priv) + assert.Nil(t, err) + + _, priv2 := p256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) assert.NotNil(t, err) } @@ -438,6 +562,16 @@ CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= -----END NEBULA CERTIFICATE----- +` + + p256 := ` +# p256 certificate +-----BEGIN NEBULA CERTIFICATE----- +CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 +6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H +76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC +IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX +-----END NEBULA CERTIFICATE----- ` rootCA := NebulaCertificate{ @@ -452,6 +586,12 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= }, } + rootCAP256 := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "nebula P256 test", + }, + } + p, err := NewCAPoolFromBytes([]byte(noNewLines)) assert.Nil(t, err) assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) @@ -474,6 +614,11 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") assert.Equal(t, len(pppp.CAs), 3) + + ppppp, err := NewCAPoolFromBytes([]byte(p256)) + assert.Nil(t, err) + assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) + assert.Equal(t, len(ppppp.CAs), 1) } func appendByteSlices(b ...[]byte) []byte { @@ -529,11 +674,16 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB assert.EqualError(t, err, "input did not contain a valid PEM encoded block") } -func TestUnmarshalEd25519PrivateKey(t *testing.T) { +func TestUnmarshalSigningPrivateKey(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA ED25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NEBULA ED25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA ECDSA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ECDSA P256 PRIVATE KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA ED25519 PRIVATE KEY----- @@ -550,39 +700,139 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -END NEBULA ED25519 PRIVATE KEY-----`) - keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) // Success test case - k, rest, err := UnmarshalEd25519PrivateKey(keyBundle) + k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle) assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalSigningPrivateKey(rest) + assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) assert.Nil(t, err) // Fail due to short key - k, rest, err = 
UnmarshalEd25519PrivateKey(rest) + k, rest, curve, err = UnmarshalSigningPrivateKey(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner - k, rest, err = UnmarshalEd25519PrivateKey(rest) + k, rest, curve, err = UnmarshalSigningPrivateKey(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 private key banner") + assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. - k, rest, err = UnmarshalEd25519PrivateKey(rest) + k, rest, curve, err = UnmarshalSigningPrivateKey(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) + assert.Nil(t, err) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation 
boundary. + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) + assert.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, rest, []byte{}) } -func TestUnmarshalX25519PrivateKey(t *testing.T) { +func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) + assert.Nil(t, err) + + // Verify the "key" can be decrypted successfully + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, []byte{}) + assert.Nil(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} + +func TestUnmarshalPrivateKey(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA X25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA X25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PRIVATE KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA X25519 PRIVATE KEY----- @@ -599,29 +849,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA X25519 PRIVATE KEY-----`) - keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) // Success test case - k, rest, err := UnmarshalX25519PrivateKey(keyBundle) + k, rest, curve, err := UnmarshalPrivateKey(keyBundle) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalPrivateKey(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) assert.Nil(t, err) // Fail due to short key - k, rest, err = UnmarshalX25519PrivateKey(rest) + k, rest, curve, err = UnmarshalPrivateKey(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 private key") + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner - k, rest, err = UnmarshalX25519PrivateKey(rest) + k, rest, curve, err = UnmarshalPrivateKey(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 private key banner") + assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. 
- k, rest, err = UnmarshalX25519PrivateKey(rest) + k, rest, curve, err = UnmarshalPrivateKey(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) assert.EqualError(t, err, "input did not contain a valid PEM encoded block") @@ -681,6 +939,12 @@ func TestUnmarshalX25519PublicKey(t *testing.T) { -----BEGIN NEBULA X25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA X25519 PUBLIC KEY----- +`) + pubP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA X25519 PUBLIC KEY----- @@ -697,29 +961,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA X25519 PUBLIC KEY-----`) - keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) + keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) // Success test case - k, rest, err := UnmarshalX25519PublicKey(keyBundle) + k, rest, curve, err := UnmarshalPublicKey(keyBundle) assert.Equal(t, len(k), 32) assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + + // Success test case + k, rest, curve, err = UnmarshalPublicKey(rest) + assert.Equal(t, len(k), 65) + assert.Nil(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) // Fail due to short key - k, rest, err = UnmarshalX25519PublicKey(rest) + k, rest, curve, err = UnmarshalPublicKey(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 public key") + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner - k, rest, err = UnmarshalX25519PublicKey(rest) + k, rest, curve, err = UnmarshalPublicKey(rest) assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 public key banner") + assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. 
- k, rest, err = UnmarshalX25519PublicKey(rest) + k, rest, curve, err = UnmarshalPublicKey(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) assert.EqualError(t, err, "input did not contain a valid PEM encoded block") @@ -816,13 +1088,56 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] nc.Details.Groups = groups } - err = nc.Sign(priv) + err = nc.Sign(Curve_CURVE25519, priv) if err != nil { return nil, nil, nil, err } return nc, pub, priv, nil } +func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + nc := &NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + Curve: Curve_P256, + InvertedGroups: make(map[string]struct{}), + }, + } + + if len(ips) > 0 { + nc.Details.Ips = ips + } + + if len(subnets) > 0 { + nc.Details.Subnets = subnets + } + + if len(groups) > 0 { + nc.Details.Groups = groups + } + + err = nc.Sign(Curve_P256, rawPriv) + if err != nil { + return nil, nil, nil, err + } + return nc, pub, rawPriv, nil +} + func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { issuer, err := ca.Sha256Sum() if err != nil { @@ -856,7 +1171,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips } } - pub, rawPriv := x25519Keypair() + var pub, rawPriv []byte + + switch ca.Details.Curve { + case Curve_CURVE25519: + pub, rawPriv = x25519Keypair() + case Curve_P256: + pub, rawPriv = p256Keypair() + default: + return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve) + } nc := &NebulaCertificate{ Details: NebulaCertificateDetails{ @@ -868,12 +1192,13 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, IsCA: false, + Curve: ca.Details.Curve, Issuer: issuer, InvertedGroups: make(map[string]struct{}), }, } - err = nc.Sign(key) + err = nc.Sign(ca.Details.Curve, key) if err != nil { return nil, nil, nil, err } @@ -894,3 +1219,12 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } + +func p256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cert/crypto.go b/cert/crypto.go new file mode 100644 index 000000000..3558e1a54 --- /dev/null +++ b/cert/crypto.go @@ -0,0 +1,143 @@ +package cert + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + + "golang.org/x/crypto/argon2" +) + +// KDF factors +type Argon2Parameters struct { + version rune + Memory uint32 // KiB + Parallelism uint8 + Iterations uint32 + salt []byte +} + +// Returns a new Argon2Parameters object with current version set +func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { + return &Argon2Parameters{ + version: argon2.Version, + Memory: 
memory, // KiB + Parallelism: parallelism, + Iterations: iterations, + } +} + +// Encrypts data using AES-256-GCM and the Argon2id key derivation function +func aes256Encrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + // this should never happen, but since this dictates how our calls into the + // aes package behave and could be catastraphic, let's sanity check this + if len(key) != 32 { + return nil, fmt.Errorf("invalid AES-256 key length (%d) - cowardly refusing to encrypt", len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nonce, data, nil) + blob := joinNonceCiphertext(nonce, ciphertext) + + return blob, nil +} + +// Decrypts data using AES-256-GCM and the Argon2id key derivation function +// Expects the data to include an Argon2id parameter string before the encrypted data +func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize()) + if err != nil { + return nil, err + } + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("invalid passphrase or corrupt private key") + } + + return plaintext, nil +} + +func aes256DeriveKey(passphrase []byte, params *Argon2Parameters) ([]byte, error) { + if params.salt == nil { + params.salt = make([]byte, 32) + if _, err := rand.Read(params.salt); err != nil { + return nil, err + } + } + + // keySize of 32 bytes will result in AES-256 encryption + key, err := deriveKey(passphrase, 32, params) + if err != nil { + return nil, err + } + + return key, nil +} + +// Derives a key from a passphrase using Argon2id +func deriveKey(passphrase []byte, keySize uint32, params *Argon2Parameters) ([]byte, error) { + if params.version != argon2.Version { + return nil, fmt.Errorf("incompatible Argon2 version: %d", params.version) + } + + if params.salt == nil { + return nil, fmt.Errorf("salt must be set in argon2Parameters") + } else if len(params.salt) < 16 { + return nil, fmt.Errorf("salt must be at least 128 bits") + } + + key := argon2.IDKey(passphrase, params.salt, params.Iterations, params.Memory, params.Parallelism, keySize) + + return key, nil +} + +// Prepends nonce to ciphertext +func joinNonceCiphertext(nonce []byte, ciphertext []byte) []byte { + return append(nonce, ciphertext...) 
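// The stored blob is simply nonce||ciphertext; splitNonceCiphertext below
// undoes this using the cipher's nonce size, e.g. (sketch):
//
//	blob := joinNonceCiphertext(nonce, ct)  // len(blob) == len(nonce)+len(ct)
//	n, ct2, err := splitNonceCiphertext(blob, gcm.NonceSize())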
+} + +// Splits nonce from ciphertext +func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { + if len(blob) <= nonceSize { + return nil, nil, fmt.Errorf("invalid ciphertext blob - blob shorter than nonce length") + } + + return blob[:nonceSize], blob[nonceSize:], nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go new file mode 100644 index 000000000..c2e61df07 --- /dev/null +++ b/cert/crypto_test.go @@ -0,0 +1,25 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/argon2" +) + +func TestNewArgon2Parameters(t *testing.T) { + p := NewArgon2Parameters(64*1024, 4, 3) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 64 * 1024, + Parallelism: 4, + Iterations: 3, + }, p) + p = NewArgon2Parameters(2*1024*1024, 2, 1) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 2 * 1024 * 1024, + Parallelism: 2, + Iterations: 1, + }, p) +} diff --git a/cert/errors.go b/cert/errors.go index 3135467ea..05b42d10c 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -1,9 +1,14 @@ package cert -import "errors" +import ( + "errors" +) var ( - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrSignatureMismatch = errors.New("certificate signature did not match") ) diff --git a/cidr/tree4.go b/cidr/tree4.go index 28d0e784d..c5ebe54a7 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -6,28 +6,36 @@ import ( "github.com/slackhq/nebula/iputil" ) -type Node struct { - left *Node - right *Node - parent *Node - value interface{} +type Node[T any] struct { + left *Node[T] + right *Node[T] + parent *Node[T] + hasValue bool + value T } -type Tree4 struct { - root *Node +type entry[T any] struct { + CIDR *net.IPNet + Value T +} + +type Tree4[T any] struct { + root *Node[T] + list []entry[T] } const ( startbit = iputil.VpnIp(0x80000000) ) -func NewTree4() *Tree4 { - tree := new(Tree4) - tree.root = &Node{} +func NewTree4[T any]() *Tree4[T] { + tree := new(Tree4[T]) + tree.root = &Node[T]{} + tree.list = []entry[T]{} return tree } -func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { +func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { bit := startbit node := tree.root next := tree.root @@ -53,13 +61,23 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // We already have this range so update the value if next != nil { + addCIDR := cidr.String() + for i, v := range tree.list { + if addCIDR == v.CIDR.String() { + tree.list = append(tree.list[:i], tree.list[i+1:]...) 
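// the old list entry for this CIDR is dropped here and the updated value is
// re-appended below, so List() keeps one entry per CIDR ordered by the most
// recent AddCIDR call (see TestCIDRTree_List)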
+ break + } + } + + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) node.value = val + node.hasValue = true return } // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -74,16 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + node.hasValue = true + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) } -// Finds the first match, which may be the least specific -func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { +// Contains finds the first match, which may be the least specific +func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { - return node.value + if node.hasValue { + return true, node.value } if ip&bit != 0 { @@ -96,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { } - return value + return false, value } -// Finds the most specific match -func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { +// MostSpecificContains finds the most specific match +func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -118,17 +139,25 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit >>= 1 } - return value + return ok, value } -// Finds the most specific match -func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { +type eachFunc[T any] func(T) bool + +// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete +// The final return value will be true if the provided function returned true +func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { bit := startbit node := tree.root - lastNode := node for node != nil { - lastNode = node + if node.hasValue { + // If the each func returns true then we can exit the loop + if each(node.value) { + return true + } + } + if ip&bit != 0 { node = node.right } else { @@ -138,8 +167,37 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { bit >>= 1 } - if bit == 0 && lastNode != nil { - value = lastNode.value + return false +} + +// GetCIDR returns the entry added by the most recent matching AddCIDR call +func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { + bit := startbit + node := tree.root + + ip := iputil.Ip2VpnIp(cidr.IP) + mask := iputil.Ip2VpnIp(cidr.Mask) + + // Find our last ancestor in the tree + for node != nil && bit&mask != 0 { + if ip&bit != 0 { + node = node.right + } else { + node = node.left + } + + bit = bit >> 1 + } + + if bit&mask == 0 && node != nil { + value = node.value + ok = node.hasValue } - return value + + return ok, value +} + +// List will return all CIDRs and their current values. Do not modify the contents! 
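// A short usage sketch of the generic tree, based on the tests in this diff:
//
//	tree := cidr.NewTree4[string]()
//	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), "lan")
//	ok, v := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("10.1.2.3"))) // true, "lan"
//	for _, e := range tree.List() {
//		// e.CIDR, e.Value
//	}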
+func (tree *Tree4[T]) List() []entry[T] { + return tree.list } diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go index 07f2b0aed..cd17be4dc 100644 --- a/cidr/tree4_test.go +++ b/cidr/tree4_test.go @@ -8,8 +8,22 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCIDRTree_List(t *testing.T) { + tree := NewTree4[string]() + tree.AddCIDR(Parse("1.0.0.0/16"), "1") + tree.AddCIDR(Parse("1.0.0.0/8"), "2") + tree.AddCIDR(Parse("1.0.0.0/16"), "3") + tree.AddCIDR(Parse("1.0.0.0/16"), "4") + list := tree.List() + assert.Len(t, list, 2) + assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) + assert.Equal(t, "2", list[0].Value) + assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) + assert.Equal(t, "4", list[1].Value) +} + func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -19,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4a", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4a", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -57,59 +79,76 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - 
assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } -func TestCIDRTree_Match(t *testing.T) { - tree := NewTree4() - tree.AddCIDR(Parse("4.1.1.0/32"), "1a") - tree.AddCIDR(Parse("4.1.1.1/32"), "1b") +func TestTree4_GetCIDR(t *testing.T) { + tree := NewTree4[string]() + tree.AddCIDR(Parse("1.0.0.0/8"), "1") + tree.AddCIDR(Parse("2.1.0.0/16"), "2") + tree.AddCIDR(Parse("3.1.1.0/24"), "3") + tree.AddCIDR(Parse("4.1.1.0/24"), "4a") + tree.AddCIDR(Parse("4.1.1.1/32"), "4b") + tree.AddCIDR(Parse("4.1.2.1/32"), "4c") + tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} - IP string + IPNet *net.IPNet }{ - {"1a", "4.1.1.0"}, - {"1b", "4.1.1.1"}, + {true, "1", Parse("1.0.0.0/8")}, + {true, "2", Parse("2.1.0.0/16")}, + {true, "3", Parse("3.1.1.0/24")}, + {true, "4a", Parse("4.1.1.0/24")}, + {true, "4b", Parse("4.1.1.1/32")}, + {true, "4c", Parse("4.1.2.1/32")}, + {true, "5", Parse("254.0.0.0/4")}, + {false, "", Parse("2.0.0.0/8")}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.GetCIDR(tt.IPNet) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - - tree = NewTree4() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) } func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.1.0.0/16"), "1") tree.AddCIDR(Parse("1.2.1.1/32"), "1") tree.AddCIDR(Parse("192.2.1.1/32"), "1") @@ -129,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) { } }) } - -func BenchmarkCIDRTree_Match(b *testing.B) { - tree := NewTree4() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Match(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Match(ip) - } - }) -} diff --git a/cidr/tree6.go b/cidr/tree6.go index d13c93d51..3f2cd2a48 100644 --- a/cidr/tree6.go +++ b/cidr/tree6.go @@ -8,20 +8,20 @@ import ( const startbit6 = uint64(1 << 63) -type Tree6 struct { - root4 *Node - root6 *Node +type Tree6[T any] struct { + root4 *Node[T] + root6 *Node[T] } -func NewTree6() *Tree6 { - tree := new(Tree6) - tree.root4 = &Node{} - tree.root6 = &Node{} +func NewTree6[T any]() *Tree6[T] { + tree := new(Tree6[T]) + tree.root4 = &Node[T]{} + tree.root6 = &Node[T]{} return tree } -func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { - 
var node, next *Node +func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { + var node, next *Node[T] cidrIP, ipv4 := isIPV4(cidr.IP) if ipv4 { @@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + node.hasValue = true } // Finds the most specific match -func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { - var node *Node +func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { + var node *Node[T] wholeIP, ipv4 := isIPV4(ip) if ipv4 { @@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { bit := startbit for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { } } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root4 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) bit >>= 1 } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { ip := hi node := tree.root6 @@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { bit := startbit6 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { ip = lo } - return value + return ok, value } func isIPV4(ip net.IP) (net.IP, bool) { diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go index b6dc4c266..eb159ec74 100644 --- a/cidr/tree6_test.go +++ b/cidr/tree6_test.go @@ -9,7 +9,7 @@ import ( ) func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", 
"1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP))) + ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree6() + tree = NewTree6[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") tree.AddCIDR(Parse("::/0"), "cool6") - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0"))) - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))) + ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("::")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) } func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", "1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, } for _, tt := range tests { @@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { hi := binary.BigEndian.Uint64(ip[:8]) lo := binary.BigEndian.Uint64(ip[8:]) - assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo)) + ok, r := tree.MostSpecificContainsIpV6(hi, lo) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } } diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index ce8d5face..69df4ab05 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -1,11 +1,13 @@ package main import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "flag" "fmt" "io" - "io/ioutil" + "math" "net" "os" "strings" @@ -17,15 +19,21 @@ import ( ) type caFlags struct { - set *flag.FlagSet - name *string - duration *time.Duration - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - ips *string - subnets *string + set *flag.FlagSet + name *string + duration *time.Duration + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + ips *string + subnets *string + argonMemory *uint + argonIterations *uint + argonParallelism *uint + encryption *bool + + curve *string } func newCaFlags() *caFlags { @@ -39,10 +47,29 @@ func newCaFlags() *caFlags { cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. 
This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") + cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") + cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") + cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") + cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") return &cf } -func ca(args []string, out io.Writer, errOut io.Writer) error { +func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert.Argon2Parameters, error) { + if memory <= 0 || memory > math.MaxUint32 { + return nil, newHelpErrorf("-argon-memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if parallelism <= 0 || parallelism > math.MaxUint8 { + return nil, newHelpErrorf("-argon-parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if iterations <= 0 || iterations > math.MaxUint32 { + return nil, newHelpErrorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil +} + +func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { cf := newCaFlags() err := cf.set.Parse(args) if err != nil { @@ -58,6 +85,12 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } + var kdfParams *cert.Argon2Parameters + if *cf.encryption { + if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { + return err + } + } if *cf.duration <= 0 { return &helpError{"-duration must be greater than 0"} @@ -109,9 +142,47 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { } } - pub, rawPriv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ed25519 keys: %s", err) + var passphrase []byte + if *cf.encryption { + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("out-key must be encrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading passphrase: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + + if len(passphrase) == 0 { + return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") + } + } + + var curve cert.Curve + var pub, rawPriv []byte + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + curve = cert.Curve_CURVE25519 + pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ed25519 keys: %s", err) + } + case "P256": + var key *ecdsa.PrivateKey + curve = cert.Curve_P256 + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ecdsa keys: %s", err) + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L60 + rawPriv = key.D.FillBytes(make([]byte, 32)) + pub = elliptic.Marshal(elliptic.P256(), key.X, 
key.Y) } nc := cert.NebulaCertificate{ @@ -124,6 +195,7 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { NotAfter: time.Now().Add(*cf.duration), PublicKey: pub, IsCA: true, + Curve: curve, }, } @@ -135,22 +207,32 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } - err = nc.Sign(rawPriv) + err = nc.Sign(curve, rawPriv) if err != nil { return fmt.Errorf("error while signing: %s", err) } - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + var b []byte + if *cf.encryption { + b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + } else { + b = cert.MarshalSigningPrivateKey(curve, rawPriv) + } + + err = os.WriteFile(*cf.outKeyPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } - b, err := nc.MarshalToPEM() + b, err = nc.MarshalToPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } - err = ioutil.WriteFile(*cf.outCertPath, b, 0600) + err = os.WriteFile(*cf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -161,7 +243,7 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*cf.outQRPath, b, 0600) + err = os.WriteFile(*cf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 372a4f19b..3a534051d 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -5,8 +5,10 @@ package main import ( "bytes" - "io/ioutil" + "encoding/pem" + "errors" "os" + "strings" "testing" "time" @@ -26,8 +28,18 @@ func Test_caHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " -argon-iterations uint\n"+ + " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ + " -argon-memory uint\n"+ + " \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+ + " -argon-parallelism uint\n"+ + " \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+ + " -curve string\n"+ + " \tEdDSA/ECDSA Curve (25519, P256) (default \"25519\")\n"+ " -duration duration\n"+ " \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+ + " -encrypt\n"+ + " \tOptional: prompt for passphrase and write out-key in an encrypted format\n"+ " -groups string\n"+ " \tOptional: comma separated list of groups. 
This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ @@ -50,18 +62,38 @@ func Test_ca(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + + pwPromptOb := "Enter passphrase: " + // required args - assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required") + assertHelpError(t, ca( + []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -69,12 +101,12 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) os.Remove(keyF.Name()) @@ -82,12 +114,12 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp cert file - crtF, err := ioutil.TempFile("", "test.crt") + crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) os.Remove(crtF.Name()) os.Remove(keyF.Name()) @@ -96,18 +128,18 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 64) - rb, _ = 
ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -122,19 +154,67 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", lCrt.Details.Issuer) assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + // test encrypted key + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, ca(args, ob, eb, testpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // read encrypted key file and verify default params + rb, _ = os.ReadFile(keyF.Name()) + k, _ := pem.Decode(rb) + ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) + assert.Nil(t, err) + // we won't know salt in advance, so just check start of string + assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) + assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) + assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations) + + // verify the key is valid and decrypt-able + var curve cert.Curve + curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Nil(t, err) + assert.Len(t, b, 0) + assert.Len(t, lKey, 64) + + // test when reading passsword results in an error + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Error(t, ca(args, ob, eb, errpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // test when user fails to enter a password + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up + assert.Equal(t, "", eb.String()) + // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -143,7 +223,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite 
existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 4f15af8dc..d94cbf145 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "github.com/slackhq/nebula/cert" @@ -14,6 +13,8 @@ type keygenFlags struct { set *flag.FlagSet outKeyPath *string outPubPath *string + + curve *string } func newKeygenFlags() *keygenFlags { @@ -21,6 +22,7 @@ func newKeygenFlags() *keygenFlags { cf.set.Usage = func() {} cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") + cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)") return &cf } @@ -38,14 +40,25 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return err } - pub, rawPriv := x25519Keypair() + var pub, rawPriv []byte + var curve cert.Curve + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + pub, rawPriv = x25519Keypair() + curve = cert.Curve_CURVE25519 + case "P256": + pub, rawPriv = p256Keypair() + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) + } - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600) + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } - err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalX25519PublicKey(pub), 0600) + err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 52f71f516..9a3b3f3bb 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "io/ioutil" "os" "testing" @@ -22,6 +21,8 @@ func Test_keygenHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" keygen : create a public/private key pair. 
the public key can be passed to `nebula-cert sign`\n"+ + " -curve string\n"+ + " \tECDH Curve (25519, P256) (default \"25519\")\n"+ " -out-key string\n"+ " \tRequired: path to write the private key to\n"+ " -out-pub string\n"+ @@ -52,7 +53,7 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) defer os.Remove(keyF.Name()) @@ -65,7 +66,7 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp pub file - pubF, err := ioutil.TempFile("", "test.pub") + pubF, err := os.CreateTemp("", "test.pub") assert.Nil(t, err) defer os.Remove(pubF.Name()) @@ -78,13 +79,13 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) - rb, _ = ioutil.ReadFile(pubF.Name()) + rb, _ = os.ReadFile(pubF.Name()) lPub, b, err := cert.UnmarshalX25519PublicKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index 3fba40ad1..b803d30a6 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -62,11 +62,11 @@ func main() { switch args[0] { case "ca": - err = ca(args[1:], os.Stdout, os.Stderr) + err = ca(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "keygen": err = keygen(args[1:], os.Stdout, os.Stderr) case "sign": - err = signCert(args[1:], os.Stdout, os.Stderr) + err = signCert(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "print": err = printCert(args[1:], os.Stdout, os.Stderr) case "verify": diff --git a/cmd/nebula-cert/passwords.go b/cmd/nebula-cert/passwords.go new file mode 100644 index 000000000..8129560ef --- /dev/null +++ b/cmd/nebula-cert/passwords.go @@ -0,0 +1,28 @@ +package main + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/term" +) + +var ErrNoTerminal = errors.New("cannot read password from nonexistent terminal") + +type PasswordReader interface { + ReadPassword() ([]byte, error) +} + +type StdinPasswordReader struct{} + +func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return nil, ErrNoTerminal + } + + password, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + + return password, err +} diff --git a/cmd/nebula-cert/passwords_test.go b/cmd/nebula-cert/passwords_test.go new file mode 100644 index 000000000..d0b64b936 --- /dev/null +++ b/cmd/nebula-cert/passwords_test.go @@ -0,0 +1,10 @@ +package main + +type StubPasswordReader struct { + password []byte + err error +} + +func (pr *StubPasswordReader) ReadPassword() ([]byte, error) { + return pr.password, pr.err +} diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 222dbc058..746d6a3ab 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "strings" @@ -41,7 +40,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCert, err := ioutil.ReadFile(*pf.path) + rawCert, err := os.ReadFile(*pf.path) if err != nil { return fmt.Errorf("unable to read cert; %s", err) } @@ -87,7 +86,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*pf.outQRPath, 
b, 0600) + err = os.WriteFile(*pf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 5d6a38e19..9fa8a5492 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "io/ioutil" "os" "testing" "time" @@ -54,7 +53,7 @@ func Test_printCert(t *testing.T) { // invalid cert at path ob.Reset() eb.Reset() - tf, err := ioutil.TempFile("", "print-cert") + tf, err := os.CreateTemp("", "print-cert") assert.Nil(t, err) defer os.Remove(tf.Name()) @@ -87,7 +86,7 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", + "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 
0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", ob.String(), ) assert.Equal(t, "", eb.String()) @@ -115,7 +114,7 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", ob.String(), ) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 4b3b899ff..35d644689 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -1,11 +1,11 @@ package main import ( + "crypto/ecdh" "crypto/rand" "flag" "fmt" "io" - "io/ioutil" "net" "os" "strings" @@ -49,7 +49,7 @@ func newSignFlags() *signFlags { } -func signCert(args []string, out io.Writer, 
errOut io.Writer) error { +func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { sf := newSignFlags() err := sf.set.Parse(args) if err != nil { @@ -72,17 +72,46 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return newHelpErrorf("cannot set both -in-pub and -out-key") } - rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath) + rawCAKey, err := os.ReadFile(*sf.caKeyPath) if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } - caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey) - if err != nil { + var curve cert.Curve + var caKey []byte + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) + if err == cert.ErrPrivateKeyEncrypted { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { return fmt.Errorf("error while parsing ca-key: %s", err) } - rawCACert, err := ioutil.ReadFile(*sf.caCertPath) + rawCACert, err := os.ReadFile(*sf.caCertPath) if err != nil { return fmt.Errorf("error while reading ca-crt: %s", err) } @@ -92,7 +121,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while parsing ca-crt: %s", err) } - if err := caCert.VerifyPrivateKey(caKey); err != nil { + if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { return fmt.Errorf("refusing to sign, root certificate does not match private key") } @@ -148,16 +177,20 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { var pub, rawPriv []byte if *sf.inPubPath != "" { - rawPub, err := ioutil.ReadFile(*sf.inPubPath) + rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } - pub, _, err = cert.UnmarshalX25519PublicKey(rawPub) + var pubCurve cert.Curve + pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) } + if pubCurve != curve { + return fmt.Errorf("curve of in-pub does not match ca") + } } else { - pub, rawPriv = x25519Keypair() + pub, rawPriv = newKeypair(curve) } nc := cert.NebulaCertificate{ @@ -171,6 +204,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { PublicKey: pub, IsCA: false, Issuer: issuer, + Curve: curve, }, } @@ -190,7 +224,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - err = nc.Sign(caKey) + err = nc.Sign(curve, caKey) if err != nil { return fmt.Errorf("error while signing: %s", err) } @@ -200,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } - err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600) + err = 
os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -211,7 +245,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while marshalling certificate: %s", err) } - err = ioutil.WriteFile(*sf.outCertPath, b, 0600) + err = os.WriteFile(*sf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -222,7 +256,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*sf.outQRPath, b, 0600) + err = os.WriteFile(*sf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -231,6 +265,17 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return nil } +func newKeypair(curve cert.Curve) ([]byte, []byte) { + switch curve { + case cert.Curve_CURVE25519: + return x25519Keypair() + case cert.Curve_P256: + return p256Keypair() + default: + return nil, nil + } +} + func x25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { @@ -245,6 +290,15 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } +func p256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} + func signSummary() string { return "sign : create and sign a certificate" } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4976fa368..adf83a267 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -6,7 +6,7 @@ package main import ( "bytes" "crypto/rand" - "io/ioutil" + "errors" "os" "testing" "time" @@ -58,17 +58,39 @@ func Test_signCert(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + // required args - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-ip is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", 
"-out-key", "nope"}, ob, eb, nopw, + ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -76,17 +98,17 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() - caKeyF, err := ioutil.TempFile("", "sign-cert.key") + caKeyF, err := os.CreateTemp("", "sign-cert.key") assert.Nil(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,19 +120,19 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed to unmarshal cert ob.Reset() eb.Reset() - caCrtF, err := ioutil.TempFile("", "sign-cert.crt") + caCrtF, err := os.CreateTemp("", "sign-cert.crt") assert.Nil(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -129,19 +151,19 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed to unmarshal pub ob.Reset() eb.Reset() - inPubF, err := ioutil.TempFile("", "in.pub") + inPubF, err := os.CreateTemp("", "in.pub") assert.Nil(t, err) defer os.Remove(inPubF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, 
nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,14 +177,14 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -170,20 +192,20 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) - caKeyF2, err := ioutil.TempFile("", "sign-cert-2.key") + caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") assert.Nil(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2)) @@ -191,7 +213,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate does not match private key") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -199,12 +221,12 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: 
"+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) os.Remove(keyF.Name()) @@ -212,13 +234,13 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file - crtF, err := ioutil.TempFile("", "test.crt") + crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) os.Remove(crtF.Name()) @@ -226,18 +248,18 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) - rb, _ = ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -268,12 +290,12 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub - rb, _ = ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -283,7 +305,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -291,14 +313,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", 
"-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -306,14 +328,83 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) + + // create valid cert/key using encrypted CA key + os.Remove(caKeyF.Name()) + os.Remove(caCrtF.Name()) + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + + caKeyF, err = os.CreateTemp("", "sign-cert.key") + assert.Nil(t, err) + defer os.Remove(caKeyF.Name()) + + caCrtF, err = os.CreateTemp("", "sign-cert.crt") + assert.Nil(t, err) + defer os.Remove(caCrtF.Name()) + + // generate the encrypted key + caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader) + kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3) + b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams) + caKeyF.Write(b) + + ca = cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "ca", + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute * 200), + PublicKey: caPub, + IsCA: true, + }, + } + b, _ = ca.MarshalToPEM() + caCrtF.Write(b) + + // test with the proper password + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Nil(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the wrong password + ob.Reset() + eb.Reset() + + testpw.password = []byte("invalid password") + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", 
"10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the user not entering a password + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, nopw)) + // normally the user hitting enter on the prompt would add newlines between these + assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test an error condition + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, errpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) } diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 51b9a9303..c9559136f 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "strings" "time" @@ -40,7 +39,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := ioutil.ReadFile(*vf.caPath) + rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %s", err) } @@ -57,7 +56,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { } } - rawCert, err := ioutil.ReadFile(*vf.certPath) + rawCert, err := os.ReadFile(*vf.certPath) if err != nil { return fmt.Errorf("unable to read crt; %s", err) } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f56243354..f0f4c78dd 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "crypto/rand" - "io/ioutil" "os" "testing" "time" @@ -56,7 +55,7 @@ func Test_verify(t *testing.T) { // invalid ca at path ob.Reset() eb.Reset() - caFile, err := ioutil.TempFile("", "verify-ca") + caFile, err := os.CreateTemp("", "verify-ca") assert.Nil(t, err) defer os.Remove(caFile.Name()) @@ -77,7 +76,7 @@ func Test_verify(t *testing.T) { IsCA: true, }, } - ca.Sign(caPriv) + ca.Sign(cert.Curve_CURVE25519, caPriv) b, _ := ca.MarshalToPEM() caFile.Truncate(0) caFile.Seek(0, 0) @@ -92,7 +91,7 @@ func Test_verify(t *testing.T) { // invalid crt at path ob.Reset() eb.Reset() - certFile, err := ioutil.TempFile("", "verify-cert") + certFile, err := os.CreateTemp("", "verify-cert") assert.Nil(t, err) defer os.Remove(certFile.Name()) @@ -117,7 +116,7 @@ func Test_verify(t *testing.T) { }, } - crt.Sign(badPriv) + crt.Sign(cert.Curve_CURVE25519, badPriv) b, _ = crt.MarshalToPEM() certFile.Truncate(0) certFile.Seek(0, 0) @@ -129,7 +128,7 @@ func Test_verify(t *testing.T) { assert.EqualError(t, err, "certificate signature did not match") // verified cert at path - crt.Sign(caPriv) + crt.Sign(cert.Curve_CURVE25519, caPriv) b, _ = crt.MarshalToPEM() certFile.Truncate(0) certFile.Seek(0, 0) diff --git a/cmd/nebula-service/main.go 
b/cmd/nebula-service/main.go index c1de26722..8d0eaa1db 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -59,13 +59,8 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index e9b285e7a..5cf0a028a 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -53,18 +53,14 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } if !*configTest { ctrl.Start() + notifyReady(l) ctrl.ShutdownBlock() } diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go new file mode 100644 index 000000000..8c3dca558 --- /dev/null +++ b/cmd/nebula/notify_linux.go @@ -0,0 +1,42 @@ +package main + +import ( + "net" + "os" + "time" + + "github.com/sirupsen/logrus" +) + +// SdNotifyReady tells systemd the service is ready and dependent services can now be started +// https://www.freedesktop.org/software/systemd/man/sd_notify.html +// https://www.freedesktop.org/software/systemd/man/systemd.service.html +const SdNotifyReady = "READY=1" + +func notifyReady(l *logrus.Logger) { + sockName := os.Getenv("NOTIFY_SOCKET") + if sockName == "" { + l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + return + } + + conn, err := net.DialTimeout("unixgram", sockName, time.Second) + if err != nil { + l.WithError(err).Error("failed to connect to systemd notification socket") + return + } + defer conn.Close() + + err = conn.SetWriteDeadline(time.Now().Add(time.Second)) + if err != nil { + l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") + return + } + + if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { + l.WithError(err).Error("failed to signal the systemd notification socket") + return + } + + l.Debugln("notified systemd the service is ready") +} diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go new file mode 100644 index 000000000..e7758e094 --- /dev/null +++ b/cmd/nebula/notify_notlinux.go @@ -0,0 +1,10 @@ +//go:build !linux +// +build !linux + +package main + +import "github.com/sirupsen/logrus" + +func notifyReady(_ *logrus.Logger) { + // No init service to notify +} diff --git a/config/config.go b/config/config.go index 966e905bb..1aea83273 100644 --- a/config/config.go +++ b/config/config.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "math" "os" "os/signal" "path/filepath" @@ -15,7 +15,7 @@ import ( "syscall" "time" - "github.com/imdario/mergo" + "dario.cat/mergo" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) @@ -121,6 +121,10 @@ func (c *C) HasChanged(k string) bool { // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the // original path provided to Load. The old settings are shallow copied for change detection after the reload. 
func (c *C) CatchHUP(ctx context.Context) { + if c.path == "" { + return + } + ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGHUP) @@ -236,6 +240,15 @@ func (c *C) GetInt(k string, d int) int { return v } +// GetUint32 will get the uint32 for k or return the default d if not found or invalid +func (c *C) GetUint32(k string, d uint32) uint32 { + r := c.GetInt(k, int(d)) + if uint64(r) > uint64(math.MaxUint32) { + return d + } + return uint32(r) +} + // GetBool will get the bool for k or return the default d if not found or invalid func (c *C) GetBool(k string, d bool) bool { r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) @@ -348,7 +361,7 @@ func (c *C) parse() error { var m map[interface{}]interface{} for _, path := range c.files { - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { return err } diff --git a/config/config_test.go b/config/config_test.go index 52bf2e479..fa9439302 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,13 +1,12 @@ package config import ( - "io/ioutil" "os" "path/filepath" "testing" "time" - "github.com/imdario/mergo" + "dario.cat/mergo" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,10 +15,10 @@ import ( func TestConfig_Load(t *testing.T) { l := test.NewLogger() - dir, err := ioutil.TempDir("", "config-test") + dir, err := os.MkdirTemp("", "config-test") // invalid yaml c := NewC(l) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge @@ -29,8 +28,8 @@ func TestConfig_Load(t *testing.T) { assert.Nil(t, err) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) - ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) assert.Nil(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ @@ -120,9 +119,9 @@ func TestConfig_HasChanged(t *testing.T) { func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) - dir, err := ioutil.TempDir("", "config-test") + dir, err := os.MkdirTemp("", "config-test") assert.Nil(t, err) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) assert.Nil(t, c.Load(dir)) @@ -131,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) { assert.False(t, c.HasChanged("outer")) assert.False(t, c.HasChanged("")) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) c.RegisterReloadCallback(func(c *C) { done <- true diff --git a/connection_manager.go b/connection_manager.go index 813542166..0b277b5c1 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -1,53 +1,80 @@ package nebula import ( + "bytes" "context" "sync" "time" + "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" + 
"github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/udp" ) -// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet -// and something like every 10 packets we could lock, send 10, then unlock for a moment +type trafficDecision int + +const ( + doNothing trafficDecision = 0 + deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote + closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote + swapPrimary trafficDecision = 3 + migrateRelays trafficDecision = 4 + tryRehandshake trafficDecision = 5 + sendTestPacket trafficDecision = 6 +) type connectionManager struct { - hostMap *HostMap - in map[uint32]struct{} - inLock *sync.RWMutex - out map[uint32]struct{} - outLock *sync.RWMutex - TrafficTimer *LockingTimerWheel[uint32] - intf *Interface + in map[uint32]struct{} + inLock *sync.RWMutex - pendingDeletion map[uint32]int - pendingDeletionLock *sync.RWMutex - pendingDeletionTimer *LockingTimerWheel[uint32] + out map[uint32]struct{} + outLock *sync.RWMutex - checkInterval int - pendingDeletionInterval int + // relayUsed holds which relay localIndexs are in use + relayUsed map[uint32]struct{} + relayUsedLock *sync.RWMutex + + hostMap *HostMap + trafficTimer *LockingTimerWheel[uint32] + intf *Interface + pendingDeletion map[uint32]struct{} + punchy *Punchy + checkInterval time.Duration + pendingDeletionInterval time.Duration + metricsTxPunchy metrics.Counter l *logrus.Logger - // I wanted to call one matLock } -func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { +func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { + var max time.Duration + if checkInterval < pendingDeletionInterval { + max = pendingDeletionInterval + } else { + max = checkInterval + } + nc := &connectionManager{ hostMap: intf.hostMap, in: make(map[uint32]struct{}), inLock: &sync.RWMutex{}, out: make(map[uint32]struct{}), outLock: &sync.RWMutex{}, - TrafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60), + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, + trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), intf: intf, - pendingDeletion: make(map[uint32]int), - pendingDeletionLock: &sync.RWMutex{}, - pendingDeletionTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60), + pendingDeletion: make(map[uint32]struct{}), checkInterval: checkInterval, pendingDeletionInterval: pendingDeletionInterval, + punchy: punchy, + metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), l: l, } + nc.Start(ctx) return nc } @@ -74,65 +101,47 @@ func (n *connectionManager) Out(localIndex uint32) { } n.outLock.RUnlock() n.outLock.Lock() - // double check since we dropped the lock temporarily - if _, ok := n.out[localIndex]; ok { - n.outLock.Unlock() - return - } n.out[localIndex] = struct{}{} - n.AddTrafficWatch(localIndex, n.checkInterval) n.outLock.Unlock() } -func (n *connectionManager) CheckIn(localIndex uint32) bool { - n.inLock.RLock() - if _, ok := n.in[localIndex]; ok { - n.inLock.RUnlock() - return true +func (n *connectionManager) RelayUsed(localIndex uint32) { + n.relayUsedLock.RLock() + // If this already exists, return + if _, ok := n.relayUsed[localIndex]; ok { + n.relayUsedLock.RUnlock() + return } - 
n.inLock.RUnlock() - return false + n.relayUsedLock.RUnlock() + n.relayUsedLock.Lock() + n.relayUsed[localIndex] = struct{}{} + n.relayUsedLock.Unlock() } -func (n *connectionManager) ClearLocalIndex(localIndex uint32) { +// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and +// resets the state for this local index +func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { n.inLock.Lock() n.outLock.Lock() + _, in := n.in[localIndex] + _, out := n.out[localIndex] delete(n.in, localIndex) delete(n.out, localIndex) n.inLock.Unlock() n.outLock.Unlock() + return in, out } -func (n *connectionManager) ClearPendingDeletion(localIndex uint32) { - n.pendingDeletionLock.Lock() - delete(n.pendingDeletion, localIndex) - n.pendingDeletionLock.Unlock() -} - -func (n *connectionManager) AddPendingDeletion(localIndex uint32) { - n.pendingDeletionLock.Lock() - if _, ok := n.pendingDeletion[localIndex]; ok { - n.pendingDeletion[localIndex] += 1 - } else { - n.pendingDeletion[localIndex] = 0 - } - n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval)) - n.pendingDeletionLock.Unlock() -} - -func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool { - n.pendingDeletionLock.RLock() - if _, ok := n.pendingDeletion[localIndex]; ok { - - n.pendingDeletionLock.RUnlock() - return true +func (n *connectionManager) AddTrafficWatch(localIndex uint32) { + // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index + n.outLock.Lock() + if _, ok := n.out[localIndex]; ok { + n.outLock.Unlock() + return } - n.pendingDeletionLock.RUnlock() - return false -} - -func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) { - n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds)) + n.out[localIndex] = struct{}{} + n.trafficTimer.Add(localIndex, n.checkInterval) + n.outLock.Unlock() } func (n *connectionManager) Start(ctx context.Context) { @@ -140,6 +149,7 @@ func (n *connectionManager) Start(ctx context.Context) { } func (n *connectionManager) Run(ctx context.Context) { + //TODO: this tick should be based on the min wheel tick? Check firewall clockSource := time.NewTicker(500 * time.Millisecond) defer clockSource.Stop() @@ -151,144 +161,322 @@ func (n *connectionManager) Run(ctx context.Context) { select { case <-ctx.Done(): return + case now := <-clockSource.C: - n.HandleMonitorTick(now, p, nb, out) - n.HandleDeletionTick(now) + n.trafficTimer.Advance(now) + for { + localIndex, has := n.trafficTimer.Purge() + if !has { + break + } + + n.doTrafficCheck(localIndex, p, nb, out, now) + } } } } -func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) { - n.TrafficTimer.Advance(now) - for { - localIndex, has := n.TrafficTimer.Purge() - if !has { - break +func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { + decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) + + switch decision { + case deleteTunnel: + if n.hostMap.DeleteHostInfo(hostinfo) { + // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap + n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) } - // Check for traffic coming back in from this host. 
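For illustration, a minimal standalone sketch of the bookkeeping idiom the In/Out/RelayUsed helpers above rely on: check membership under a read lock, upgrade to the write lock only on a miss, and read-and-clear the flags in one critical section as getAndResetTrafficCheck does. Type and method names here are illustrative, not Nebula APIs.

package main

import "sync"

// tracker mimics the in/out activity sets: a read-mostly map guarded by an RWMutex.
type tracker struct {
	mu  sync.RWMutex
	set map[uint32]struct{}
}

func newTracker() *tracker { return &tracker{set: make(map[uint32]struct{})} }

// Mark records activity for idx. The common case (already tracked) stays on the
// read lock; only the first sighting takes the write lock. Re-adding the same
// key after the upgrade is harmless, so no double-check is needed.
func (t *tracker) Mark(idx uint32) {
	t.mu.RLock()
	_, ok := t.set[idx]
	t.mu.RUnlock()
	if ok {
		return
	}
	t.mu.Lock()
	t.set[idx] = struct{}{}
	t.mu.Unlock()
}

// Consume reports whether idx was marked since the last call and clears it,
// the same read-and-reset shape as getAndResetTrafficCheck.
func (t *tracker) Consume(idx uint32) bool {
	t.mu.Lock()
	_, ok := t.set[idx]
	delete(t.set, idx)
	t.mu.Unlock()
	return ok
}

func main() {
	t := newTracker()
	t.Mark(9901)
	_ = t.Consume(9901) // true once, then false until Mark is called again
}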
- traf := n.CheckIn(localIndex) + case closeTunnel: + n.intf.sendCloseTunnel(hostinfo) + n.intf.closeTunnel(hostinfo) - hostinfo, err := n.hostMap.QueryIndex(localIndex) - if err != nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue + case swapPrimary: + n.swapPrimary(hostinfo, primary) + + case migrateRelays: + n.migrateRelayUsed(hostinfo, primary) + + case tryRehandshake: + n.tryRehandshake(hostinfo) + + case sendTestPacket: + n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + } + + n.resetRelayTrafficCheck(hostinfo) +} + +func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { + if hostinfo != nil { + n.relayUsedLock.Lock() + defer n.relayUsedLock.Unlock() + // No need to migrate any relays, delete usage info now. + for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { + delete(n.relayUsed, idx) } + } +} + +func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { + relayFor := oldhostinfo.relayState.CopyAllRelayFor() + + for _, r := range relayFor { + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) - if n.handleInvalidCertificate(now, hostinfo) { + var index uint32 + var relayFrom iputil.VpnIp + var relayTo iputil.VpnIp + switch { + case ok && existing.State == Established: + // This relay already exists in newhostinfo, then do nothing. continue + case ok && existing.State == Requested: + // The relay exists in a Requested state; re-send the request + index = existing.LocalIndex + switch r.Type { + case TerminalType: + relayFrom = n.intf.myVpnIp + relayTo = existing.PeerIp + case ForwardingType: + relayFrom = existing.PeerIp + relayTo = newhostinfo.vpnIp + default: + // should never happen + } + case !ok: + n.relayUsedLock.RLock() + if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { + // The relay hasn't been used; don't migrate it. + n.relayUsedLock.RUnlock() + continue + } + n.relayUsedLock.RUnlock() + // The relay doesn't exist at all; create some relay state and send the request. + var err error + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + if err != nil { + n.l.WithError(err).Error("failed to migrate relay to new hostinfo") + continue + } + switch r.Type { + case TerminalType: + relayFrom = n.intf.myVpnIp + relayTo = r.PeerIp + case ForwardingType: + relayFrom = r.PeerIp + relayTo = newhostinfo.vpnIp + default: + // should never happen + } } - // If we saw an incoming packets from this ip and peer's certificate is not - // expired, just ignore. - if traf { - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") - } - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue + // Send a CreateRelayRequest to the peer. 
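A compact sketch of the decide-then-act split introduced above, where makeTrafficDecision only inspects state and returns a trafficDecision while doTrafficCheck applies the side effects afterwards. The types and transitions below are simplified stand-ins, not the real Nebula state machine.

package main

import "fmt"

type decision int

const (
	doNothing decision = iota
	sendTestPacket
	deleteTunnel
)

// peer is a toy stand-in for a HostInfo plus its traffic flags.
type peer struct {
	sawInbound   bool
	probePending bool
}

// decide only reads and updates bookkeeping; in the real code this runs under
// the hostmap read lock and returns before any packets are sent.
func decide(p *peer) decision {
	if p.sawInbound {
		p.sawInbound = false
		p.probePending = false
		return doNothing // traffic seen, tunnel is alive
	}
	if p.probePending {
		return deleteTunnel // our test packet went unanswered
	}
	p.probePending = true
	return sendTestPacket
}

// act performs the side effect chosen by decide, outside the lock.
func act(d decision) {
	switch d {
	case sendTestPacket:
		fmt.Println("send authenticated test packet")
	case deleteTunnel:
		fmt.Println("tear down dead tunnel")
	default:
		fmt.Println("nothing to do")
	}
}

func main() {
	p := &peer{}
	act(decide(p)) // first tick: probe
	act(decide(p)) // second tick with no reply: delete
}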
+ req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, + RelayFromIp: uint32(relayFrom), + RelayToIp: uint32(relayTo), } + msg, err := req.Marshal() + if err != nil { + n.l.WithError(err).Error("failed to marshal Control message to migrate relay") + } else { + n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) + n.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(req.RelayFromIp), + "relayTo": iputil.VpnIp(req.RelayToIp), + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnIp": newhostinfo.vpnIp}). + Info("send CreateRelayRequest") + } + } +} - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") +func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { + n.hostMap.RLock() + defer n.hostMap.RUnlock() - if hostinfo != nil && hostinfo.ConnectionState != nil { - // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) + hostinfo := n.hostMap.Indexes[localIndex] + if hostinfo == nil { + n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") + delete(n.pendingDeletion, localIndex) + return doNothing, nil, nil + } - } else { - hostinfo.logger(n.l).Debugf("Hostinfo sadness") - } - n.AddPendingDeletion(localIndex) + if n.isInvalidCertificate(now, hostinfo) { + delete(n.pendingDeletion, hostinfo.localIndexId) + return closeTunnel, hostinfo, nil } -} + primary := n.hostMap.Hosts[hostinfo.vpnIp] + mainHostInfo := true + if primary != nil && primary != hostinfo { + mainHostInfo = false + } -func (n *connectionManager) HandleDeletionTick(now time.Time) { - n.pendingDeletionTimer.Advance(now) - for { - localIndex, has := n.pendingDeletionTimer.Purge() - if !has { - break + // Check for traffic on this hostinfo + inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) + + // A hostinfo is determined alive if there is incoming traffic + if inTraffic { + decision := doNothing + if n.l.Level >= logrus.DebugLevel { + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). + Debug("Tunnel status") } + delete(n.pendingDeletion, hostinfo.localIndexId) - hostinfo, err := n.hostMap.QueryIndex(localIndex) - if err != nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue + if mainHostInfo { + decision = tryRehandshake + + } else { + if n.shouldSwapPrimary(hostinfo, primary) { + decision = swapPrimary + } else { + // migrate the relays to the primary, if in use. + decision = migrateRelays + } } - if n.handleInvalidCertificate(now, hostinfo) { - continue + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + + if !outTraffic { + // Send a punch packet to keep the NAT state alive + n.sendPunch(hostinfo) } - // If we saw an incoming packets from this ip and peer's certificate is not - // expired, just ignore. - traf := n.CheckIn(localIndex) - if traf { - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "alive", "method": "active"}). 
- Debug("Tunnel status") + return decision, hostinfo, primary + } + + if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { + // We have already sent a test packet and nothing was returned, this hostinfo is dead + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "dead", "method": "active"}). + Info("Tunnel status") + + delete(n.pendingDeletion, hostinfo.localIndexId) + return deleteTunnel, hostinfo, nil + } + + decision := doNothing + if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { + if !outTraffic { + // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. + // Just maintain NAT state if configured to do so. + n.sendPunch(hostinfo) + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return doNothing, nil, nil - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue } - // If it comes around on deletion wheel and hasn't resolved itself, delete - if n.checkPendingDeletion(localIndex) { - cn := "" - if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { - cn = hostinfo.ConnectionState.peerCert.Details.Name - } + if n.punchy.GetTargetEverything() { + // This is similar to the old punchy behavior with a slight optimization. + // We aren't receiving traffic but we are sending it, punch on all known + // ips in case we need to re-prime NAT state + n.sendPunch(hostinfo) + } + if n.l.Level >= logrus.DebugLevel { hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - WithField("certName", cn). - Info("Tunnel status") - - n.hostMap.DeleteHostInfo(hostinfo) + WithField("tunnelCheck", m{"state": "testing", "method": "active"}). + Debug("Tunnel status") } - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) + // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues + decision = sendTestPacket + + } else { + if n.l.Level >= logrus.DebugLevel { + hostinfo.logger(n.l).Debugf("Hostinfo sadness") + } } + + n.pendingDeletion[hostinfo.localIndexId] = struct{}{} + n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) + return decision, hostinfo, nil } -// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid -func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { - if !n.intf.disconnectInvalid { +func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { + // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. + // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. + // Let's sort this out. + + if current.vpnIp < n.intf.myVpnIp { + // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. + // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. + // The remotes vpn ip is lower than mine. I will not flip. return false } + certState := n.intf.pki.GetCertState() + return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) +} + +func (n *connectionManager) swapPrimary(current, primary *HostInfo) { + n.hostMap.Lock() + // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. 
+ if n.hostMap.Hosts[current.vpnIp] == primary { + n.hostMap.unlockedMakePrimary(current) + } + n.hostMap.Unlock() +} + +// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and +// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid +// check and return true. +func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { return false } - valid, err := remoteCert.Verify(now, n.intf.caPool) + valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) if valid { return false } + if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { + // Block listed certificates should always be disconnected + return false + } + fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). WithField("fingerprint", fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") - // Inform the remote and close the tunnel locally - n.intf.sendCloseTunnel(hostinfo) - n.intf.closeTunnel(hostinfo) - - n.ClearLocalIndex(hostinfo.localIndexId) - n.ClearPendingDeletion(hostinfo.localIndexId) return true } + +func (n *connectionManager) sendPunch(hostinfo *HostInfo) { + if !n.punchy.GetPunch() { + // Punching is disabled + return + } + + if n.punchy.GetTargetEverything() { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { + n.metricsTxPunchy.Inc(1) + n.intf.outside.WriteTo([]byte{1}, addr) + }) + + } else if hostinfo.remote != nil { + n.metricsTxPunchy.Inc(1) + n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + } +} + +func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { + certState := n.intf.pki.GetCertState() + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { + return + } + + n.l.WithField("vpnIp", hostinfo.vpnIp). + WithField("reason", "local certificate is not current"). 
+ Info("Re-handshaking with remote") + + n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) +} diff --git a/connection_manager_test.go b/connection_manager_test.go index 58fdbcdfe..f50bcf862 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -20,8 +21,9 @@ var vpnIp iputil.VpnIp func newTestLighthouse() *LightHouse { lh := &LightHouse{ - l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, + l: test.NewLogger(), + addrMap: map[iputil.VpnIp]*RemoteList{}, + queryChan: make(chan iputil.VpnIp, 10), } lighthouses := map[iputil.VpnIp]struct{}{} staticList := map[iputil.VpnIp]struct{}{} @@ -41,35 +43,38 @@ func Test_NewConnectionManagerTest(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, - outside: &udp.Conn{}, - certState: cs, + outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - now := time.Now() + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) - nc.HandleMonitorTick(now, p, nb, out) + // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnIp: vpnIp, @@ -77,33 +82,36 @@ func Test_NewConnectionManagerTest(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } - nc.hostMap.addHostInfo(hostinfo, ifce) + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) + nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // Move ahead 5s. Nothing should happen - next_tick := now.Add(5 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // Move ahead 6s. 
We haven't heard back - next_tick = now.Add(6 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // This host should now be up for deletion + assert.Contains(t, nc.out, hostinfo.localIndexId) + + // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + + // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // Move ahead some more - next_tick = now.Add(45 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // The host should be evicted + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + + // Do a final traffic check tick, the host should now be removed + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) @@ -117,35 +125,38 @@ func Test_NewConnectionManagerTest2(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, - outside: &udp.Conn{}, - certState: cs, + outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - now := time.Now() + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) - nc.HandleMonitorTick(now, p, nb, out) + // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnIp: vpnIp, @@ -153,37 +164,41 @@ func Test_NewConnectionManagerTest2(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } - nc.hostMap.addHostInfo(hostinfo, ifce) + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) - 
assert.NotContains(t, nc.pendingDeletion, vpnIp) - assert.Contains(t, nc.hostMap.Hosts, vpnIp) - // Move ahead 5s. Nothing should happen - next_tick := now.Add(5 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // Move ahead 6s. We haven't heard back - next_tick = now.Add(6 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // This host should now be up for deletion + nc.In(hostinfo.localIndexId) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + + // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + + // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, vpnIp) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // We heard back this time + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + + // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) - // Move ahead some more - next_tick = now.Add(45 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // The host should not be evicted + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) } // Check if we can disconnect the peer. @@ -199,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. 
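The DisconnectInvalid test below builds its throwaway CA from a raw Ed25519 key pair and signs with cert.Curve_CURVE25519; the signing primitive underneath is ordinary Ed25519. A self-contained reminder of that primitive using only the standard library:

package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"fmt"
)

func main() {
	// Generate a throwaway CA-style signing key.
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		panic(err)
	}

	// Sign some bytes (in Nebula this would be the marshalled certificate details).
	msg := []byte("certificate details to be signed")
	sig := ed25519.Sign(priv, msg)

	fmt.Println("signature valid:", ed25519.Verify(pub, msg, sig)) // true
}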
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) @@ -212,7 +228,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCA, }, } - caCert.Sign(privCA) + + assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) ncp := &cert.NebulaCAPool{ CAs: cert.NewCAPool().CAs, } @@ -231,52 +248,58 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { Issuer: "ca", }, } - peerCert.Sign(privCA) + assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ - hostMap: hostMap, - inside: &test.NoopTun{}, - outside: &udp.Conn{}, - certState: cs, - firewall: &Firewall{}, - lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), - l: l, - disconnectInvalid: true, - caPool: ncp, + hostMap: hostMap, + inside: &test.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{}, + lightHouse: lh, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + pki: &PKI{}, } + ifce.pki.cs.Store(cs) + ifce.pki.caPool.Store(ncp) + ifce.disconnectInvalid.Store(true) // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) ifce.connectionManager = nc - hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) - hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - peerCert: &peerCert, - H: &noise.HandshakeState{}, + + hostinfo := &HostInfo{ + vpnIp: vpnIp, + ConnectionState: &ConnectionState{ + myCert: &cert.NebulaCertificate{}, + peerCert: &peerCert, + H: &noise.HandshakeState{}, + }, } + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // Move ahead 45s. // Check if to disconnect with invalid certificate. // Should be alive. nextTick := now.Add(45 * time.Second) - destroyed := nc.handleInvalidCertificate(nextTick, hostinfo) - assert.False(t, destroyed) + invalid := nc.isInvalidCertificate(nextTick, hostinfo) + assert.False(t, invalid) // Move ahead 61s. // Check if to disconnect with invalid certificate. // Should be disconnected. 
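The test above publishes certificate state and the disconnect_invalid flag through atomic containers (ifce.pki.cs.Store, ifce.disconnectInvalid.Store), so a config reload can swap them without locking the packet path. A minimal stand-in using sync/atomic's typed wrappers from Go 1.19+; the certState fields here are hypothetical:

package main

import (
	"fmt"
	"sync/atomic"
)

// certState is a hypothetical stand-in for the data a reload replaces wholesale.
type certState struct {
	fingerprint string
}

type pki struct {
	cs                atomic.Pointer[certState]
	disconnectInvalid atomic.Bool
}

func main() {
	p := &pki{}
	p.cs.Store(&certState{fingerprint: "old"})
	p.disconnectInvalid.Store(true)

	// Hot-path readers just Load; no mutex is involved.
	fmt.Println(p.cs.Load().fingerprint, p.disconnectInvalid.Load())

	// A reload goroutine publishes a fresh value atomically.
	p.cs.Store(&certState{fingerprint: "new"})
	fmt.Println(p.cs.Load().fingerprint)
}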
nextTick = now.Add(61 * time.Second) - destroyed = nc.handleInvalidCertificate(nextTick, hostinfo) - assert.True(t, destroyed) + invalid = nc.isInvalidCertificate(nextTick, hostinfo) + assert.True(t, invalid) } diff --git a/connection_state.go b/connection_state.go index 6bbb02f72..8ef8b3a24 100644 --- a/connection_state.go +++ b/connection_state.go @@ -9,6 +9,7 @@ import ( "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 @@ -17,24 +18,34 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - certState *CertState + myCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate initiator bool messageCounter atomic.Uint64 window *Bits - queueLock sync.Mutex writeLock sync.Mutex - ready bool } -func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { - cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) - if f.cipher == "chachapoly" { - cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) +func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { + var dhFunc noise.DHFunc + switch certState.Certificate.Details.Curve { + case cert.Curve_CURVE25519: + dhFunc = noise.DH25519 + case cert.Curve_P256: + dhFunc = noiseutil.DHP256 + default: + l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) + return nil + } + + var cs noise.CipherSuite + if cipher == "chachapoly" { + cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + } else { + cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - curCertState := f.certState - static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} + static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss @@ -59,8 +70,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern H: hs, initiator: initiator, window: b, - ready: false, - certState: curCertState, + myCert: certState.Certificate, } return ci @@ -71,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "certificate": cs.peerCert, "initiator": cs.initiator, "message_counter": cs.messageCounter.Load(), - "ready": cs.ready, }) } diff --git a/control.go b/control.go index adc2a4846..c227b207b 100644 --- a/control.go +++ b/control.go @@ -11,19 +11,31 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. 
This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc +type controlEach func(h *HostInfo) + +type controlHostLister interface { + QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + ForEachIndex(each controlEach) + ForEachVpnIp(each controlEach) + GetPreferredRanges() []*net.IPNet +} + type Control struct { - f *Interface - l *logrus.Logger - cancel context.CancelFunc - sshStart func() - statsStart func() - dnsStart func() + f *Interface + l *logrus.Logger + ctx context.Context + cancel context.CancelFunc + sshStart func() + statsStart func() + dnsStart func() + lighthouseStart func() } type ControlHostInfo struct { @@ -31,7 +43,6 @@ type ControlHostInfo struct { LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []*udp.Addr `json:"remoteAddrs"` - CachedPackets int `json:"cachedPackets"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` CurrentRemote *udp.Addr `json:"currentRemote"` @@ -54,12 +65,19 @@ func (c *Control) Start() { if c.dnsStart != nil { go c.dnsStart() } + if c.lighthouseStart != nil { + c.lighthouseStart() + } // Start reading packets. c.f.run() } -// Stop signals nebula to shutdown, returns after the shutdown is complete +func (c *Control) Context() context.Context { + return c.ctx +} + +// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete func (c *Control) Stop() { // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. @@ -74,7 +92,7 @@ func (c *Control) Stop() { // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled func (c *Control) ShutdownBlock() { - sigChan := make(chan os.Signal) + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT) @@ -89,55 +107,64 @@ func (c *Control) RebindUDPServer() { _ = c.f.outside.Rebind() // Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0 - c.f.lightHouse.SendUpdate(c.f) + c.f.lightHouse.SendUpdate() // Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes c.f.rebindCount++ } -// ListHostmap returns details about the actual or pending (handshaking) hostmap -func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { +// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip +func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { + if pendingMap { + return listHostMapHosts(c.f.handshakeManager) + } else { + return listHostMapHosts(c.f.hostMap) + } +} + +// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id +func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMap(c.f.handshakeManager.pendingHostMap) + return listHostMapIndexes(c.f.handshakeManager) } else { - return listHostMap(c.f.hostMap) + return listHostMapIndexes(c.f.hostMap) } } // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { - var hm *HostMap + var hl controlHostLister if pending { - hm = c.f.handshakeManager.pendingHostMap + hl = c.f.handshakeManager } else { - hm = c.f.hostMap + hl = c.f.hostMap } - h, err := hm.QueryVpnIp(vpnIp) - if err != nil { + 
h := hl.QueryVpnIp(vpnIp) + if h == nil { return nil } - ch := copyHostInfo(h, c.f.hostMap.preferredRanges) + ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges()) return &ch } // SetRemoteForTunnel forces a tunnel to use a specific remote func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return nil } hostInfo.SetRemote(addr.Copy()) - ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges) + ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return false } @@ -189,7 +216,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { hostInfos := []*HostInfo{} // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() - for _, relayHost := range c.f.hostMap.Hosts { + for _, relayHost := range c.f.hostMap.Indexes { if _, ok := relayingHosts[relayHost.vpnIp]; !ok { hostInfos = append(hostInfos, relayHost) } @@ -205,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { return } +func (c *Control) Device() overlay.Device { + return c.f.inside +} + func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi := ControlHostInfo{ @@ -212,7 +243,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), - CachedPackets: len(h.packetStore), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), } @@ -232,15 +262,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { return chi } -func listHostMap(hm *HostMap) []ControlHostInfo { - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Hosts)) - i := 0 - for _, v := range hm.Hosts { - hosts[i] = copyHostInfo(v, hm.preferredRanges) - i++ - } - hm.RUnlock() +func listHostMapHosts(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) + return hosts +} +func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachIndex(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) return hosts } diff --git a/control_test.go b/control_test.go index ec469b469..c64a3a4b7 100644 --- a/control_test.go +++ b/control_test.go @@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) + hm := newHostMap(l, &net.IPNet{}) + hm.preferredRanges.Store(&[]*net.IPNet{}) + remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ @@ -47,10 +49,10 @@ 
func TestControl_GetHostInfoByVpnIp(t *testing.T) { Signature: []byte{1, 2, 1, 2, 1, 3}, } - remotes := NewRemoteList() + remotes := NewRemoteList(nil) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) - hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -64,9 +66,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) - hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -80,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) c := Control{ f: &Interface{ @@ -96,7 +98,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []*udp.Addr{remote2, remote1}, - CachedPackets: 0, Cert: crt.Copy(), MessageCounter: 0, CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), @@ -105,7 +106,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet diff --git a/control_tester.go b/control_tester.go index 4fa0763d4..b786ba383 100644 --- a/control_tester.go +++ b/control_tester.go @@ -21,7 +21,7 @@ import ( func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { - p := c.f.outside.Get(true) + p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } @@ -37,7 +37,7 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { - p := c.f.outside.Get(true) + p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } @@ -90,11 +90,11 @@ func (c *Control) GetFromTun(block bool) []byte { // GetFromUDP will pull a udp packet off the udp side of nebula func (c *Control) GetFromUDP(block bool) *udp.Packet { - return c.f.outside.Get(block) + return c.f.outside.(*udp.TesterConn).Get(block) } func (c *Control) GetUDPTxChan() <-chan *udp.Packet { - return c.f.outside.TxPackets + return c.f.outside.(*udp.TesterConn).TxPackets } func (c *Control) GetTunTxChan() <-chan []byte { @@ -103,7 +103,7 @@ func (c *Control) GetTunTxChan() <-chan []byte { // InjectUDPPacket will inject a packet into the udp side of nebula func (c *Control) InjectUDPPacket(p *udp.Packet) { - c.f.outside.Send(p) + c.f.outside.(*udp.TesterConn).Send(p) } // InjectTunUDPPacket puts a udp packet on the tun interface. 
Using UDP here because it's a simpler protocol @@ -143,16 +143,16 @@ func (c *Control) GetVpnIp() iputil.VpnIp { } func (c *Control) GetUDPAddr() string { - return c.f.outside.Addr.String() + return c.f.outside.(*udp.TesterConn).Addr.String() } func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)] - if !ok { + hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) + if hostinfo == nil { return false } - c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + c.f.handshakeManager.DeleteHostInfo(hostinfo) return true } @@ -161,5 +161,9 @@ func (c *Control) GetHostmap() *HostMap { } func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.certState.certificate + return c.f.pki.GetCertState().Certificate +} + +func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { + c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dist/arch/nebula.service b/dist/arch/nebula.service index 7e5335aa8..831c71a53 100644 --- a/dist/arch/nebula.service +++ b/dist/arch/nebula.service @@ -4,6 +4,8 @@ Wants=basic.target network-online.target nss-lookup.target time-sync.target After=basic.target network.target network-online.target [Service] +Type=notify +NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml diff --git a/dist/fedora/nebula.service b/dist/fedora/nebula.service deleted file mode 100644 index 21a99c558..000000000 --- a/dist/fedora/nebula.service +++ /dev/null @@ -1,14 +0,0 @@ -[Unit] -Description=Nebula overlay networking tool -Wants=basic.target network-online.target nss-lookup.target time-sync.target -After=basic.target network.target network-online.target -Before=sshd.service - -[Service] -SyslogIdentifier=nebula -ExecReload=/bin/kill -HUP $MAINPID -ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml -Restart=always - -[Install] -WantedBy=multi-user.target diff --git a/dns_server.go b/dns_server.go index 19bc5ced7..70ec0e075 100644 --- a/dns_server.go +++ b/dns_server.go @@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } iip := iputil.Ip2VpnIp(ip) - hostinfo, err := d.hostMap.QueryVpnIp(iip) - if err != nil { + hostinfo := d.hostMap.QueryVpnIp(iip) + if hostinfo == nil { return "" } q := hostinfo.GetCert() @@ -129,7 +129,12 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { } func getDnsServerAddr(c *config.C) string { - return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) + dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) + // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. 
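// Illustrative sketch (not part of the diff above): the reason normalizing the configured
// value down to a bare "::" works is that net.JoinHostPort re-adds the square brackets
// whenever the host contains a colon, so the joined result is still a listenable
// "[::]:53"-style address. Assumed standard-library behavior, shown standalone:
package main

import (
	"fmt"
	"net"
)

func main() {
	// IPv4 hosts pass through unchanged.
	fmt.Println(net.JoinHostPort("0.0.0.0", "53")) // 0.0.0.0:53

	// A host containing a colon is wrapped in brackets so the port stays unambiguous.
	fmt.Println(net.JoinHostPort("::", "53")) // [::]:53
}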
+ if dnsHost == "[::]" { + dnsHost = "::" + } + return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } func startDns(l *logrus.Logger, c *config.C) { diff --git a/dns_server_test.go b/dns_server_test.go index 830dc8a84..69f6ae84f 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/miekg/dns" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { @@ -17,3 +19,40 @@ func TestParsequery(t *testing.T) { //parseQuery(m) } + +func Test_getDnsServerAddr(t *testing.T) { + c := config.NewC(nil) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "0.0.0.0", + "port": "1", + }, + } + assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "::", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "[::]", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + + // Make sure whitespace doesn't mess us up + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "[::] ", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index bfde43ebe..59f1d0e52 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,25 +4,28 @@ package e2e import ( + "fmt" "net" "testing" "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Start the servers myControl.Start() @@ -32,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -41,19 +44,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), 
time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -74,16 +77,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -92,20 +95,20 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. 
- myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -117,7 +120,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -136,18 +139,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -157,14 +160,17 @@ func TestWrongResponderHandshake(t *testing.T) { theirControl.Stop() } -func Test_Case1_Stage1Race(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) +func TestStage1Race(t *testing.T) { + // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow + // But will eventually collapse down to a single tunnel + + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) // 
Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -175,8 +181,8 @@ func Test_Case1_Stage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -185,44 +191,165 @@ func Test_Case1_Stage1Race(t *testing.T) { r.Log("Now inject both stage 1 handshake packets") r.InjectUDPPacket(theirControl, myControl, theirHsForMe) r.InjectUDPPacket(myControl, theirControl, myHsForThem) - //TODO: they should win, grab their index for me and make sure I use it in the end. - r.Log("They should not have a stage 2 (won the race) but I should send one") - r.InjectUDPPacket(myControl, theirControl, myControl.GetFromUDP(true)) + r.Log("Route until they receive a message packet") + myCachedPacket := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) - r.Log("Route for me until I send a message packet to them") - r.RouteForAllUntilAfterMsgTypeTo(theirControl, header.Message, header.MessageNone) + r.Log("Their cached packet should be received by me") + theirCachedPacket := r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) - t.Log("My cached packet should be received by them") - myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + r.Log("Do a bidirectional tunnel test") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - t.Log("Route for them until I send a message packet to me") - theirControl.WaitForType(1, 0, myControl) + myHostmapHosts := myControl.ListHostmapHosts(false) + myHostmapIndexes := myControl.ListHostmapIndexes(false) + theirHostmapHosts := theirControl.ListHostmapHosts(false) + theirHostmapIndexes := theirControl.ListHostmapIndexes(false) - t.Log("Their cached packet should be received by me") - theirCachedPacket := myControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80) + // We should have two tunnels on both sides + assert.Len(t, myHostmapHosts, 1) + assert.Len(t, theirHostmapHosts, 1) + assert.Len(t, myHostmapIndexes, 2) + assert.Len(t, theirHostmapIndexes, 2) - t.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Spin until connection manager tears down a tunnel") + + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // We should only have a single tunnel now on both sides + 
assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() - //TODO: assert hostmaps +} + +func TestUncleanShutdownRaceLoser(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + r.Log("Trigger a handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + + r.Log("Nuke my hostmap") + myHostmap := myControl.GetHostmap() + myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Indexes = map[uint32]*nebula.HostInfo{} + myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} + + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + p = r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + + r.Log("Assert the tunnel works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.Log("Wait for the dead index to go away") + start := len(theirControl.GetHostmap().Indexes) + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + if len(theirControl.GetHostmap().Indexes) < start { + break + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) +} + +func TestUncleanShutdownRaceWinner(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + r.Log("Trigger a handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + r.Log("Nuke my hostmap") + theirHostmap := theirControl.GetHostmap() + theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + 
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} + theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} + + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + r.RenderHostmaps("Derp hostmaps", myControl, theirControl) + + r.Log("Assert the tunnel works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.Log("Wait for the dead index to go away") + start := len(myControl.GetHostmap().Indexes) + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + if len(myControl.GetHostmap().Indexes) < start { + break + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) } func TestRelays(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIp, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIp, relayUdpAddr) - myControl.InjectRelays(theirVpnIp, []net.IP{relayVpnIp}) - relayControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -234,12 +361,616 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } +func TestStage1RaceRelays(t *testing.T) { + //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := 
newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + r.Log("Get a tunnel between me and relay") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + + r.Log("Get a tunnel between them and relay") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + + r.Log("Trigger a handshake from both them and me via relay to them and me") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + + r.Log("Wait for a packet from them to me") + p := r.RouteForAllUntilTxTun(myControl) + _ = p + + r.FlushAll() + + myControl.Stop() + theirControl.Stop() + relayControl.Stop() + // + ////TODO: assert hostmaps +} + +func TestStage1RaceRelays2(t *testing.T) { + //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + l := NewTestLogger() + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + r.Log("Get a tunnel between me and relay") + l.Info("Get a tunnel between me and relay") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + + r.Log("Get a tunnel between them and relay") + l.Info("Get a tunnel between them and relay") + 
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + + r.Log("Trigger a handshake from both them and me via relay to them and me") + l.Info("Trigger a handshake from both them and me via relay to them and me") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + + //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) + //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) + + r.Log("Wait for a packet from them to me") + l.Info("Wait for a packet from them to me; myControl") + r.RouteForAllUntilTxTun(myControl) + l.Info("Wait for a packet from them to me; theirControl") + r.RouteForAllUntilTxTun(theirControl) + + r.Log("Assert the tunnel works") + l.Info("Assert the tunnel works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + + t.Log("Wait until we remove extra tunnels") + l.Info("Wait until we remove extra tunnels") + l.WithFields( + logrus.Fields{ + "myControl": len(myControl.GetHostmap().Indexes), + "theirControl": len(theirControl.GetHostmap().Indexes), + "relayControl": len(relayControl.GetHostmap().Indexes), + }).Info("Waiting for hostinfos to be removed...") + hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) + retries := 60 + for hostInfos > 6 && retries > 0 { + hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) + l.WithFields( + logrus.Fields{ + "myControl": len(myControl.GetHostmap().Indexes), + "theirControl": len(theirControl.GetHostmap().Indexes), + "relayControl": len(relayControl.GetHostmap().Indexes), + }).Info("Waiting for hostinfos to be removed...") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + retries-- + } + + r.Log("Assert the tunnel works") + l.Info("Assert the tunnel works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + + myControl.Stop() + theirControl.Stop() + relayControl.Stop() + + // + ////TODO: assert hostmaps +} +func TestRehandshakingRelays(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from 
me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + + // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, + // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. + r.Log("Renew relay certificate and spin until me and them sees it") + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + relayConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(relayConfig.Settings) + assert.NoError(t, err) + relayConfig.ReloadConfigString(string(rc)) + + for { + r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between my and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + for { + r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between their and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + // We should have two hostinfos on all sides + for len(myControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("myControl hostinfos got cleaned up!") + for len(theirControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("theirControl hostinfos got cleaned up!") + for len(relayControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("relayControl hostinfos got cleaned up!") +} + +func TestRehandshakingRelaysPrimary(t *testing.T) { + // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner + ca, _, caKey, _ := NewTestCaCert(time.Now(), 
time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + + // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, + // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. + r.Log("Renew relay certificate and spin until me and them sees it") + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + relayConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(relayConfig.Settings) + assert.NoError(t, err) + relayConfig.ReloadConfigString(string(rc)) + + for { + r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between my and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + for { + r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between their and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + // We should have two hostinfos on all sides + for len(myControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) 
+ r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("myControl hostinfos got cleaned up!") + for len(theirControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("theirControl hostinfos got cleaned up!") + for len(relayControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("relayControl hostinfos got cleaned up!") +} + +func TestRehandshaking(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + + // Put their info in our lighthouse and vice versa + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Stand up a tunnel between me and them") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Renew my certificate and spin until their sees it") + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + myConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(myConfig.Settings) + assert.NoError(t, err) + myConfig.ReloadConfigString(string(rc)) + + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + break + } + + time.Sleep(time.Second) + } + + // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly + rc, err = yaml.Marshal(theirConfig.Settings) + assert.NoError(t, err) + var theirNewConfig m + assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall["inbound"] = []m{{ + "proto": "any", + "port": "any", + "group": "new group", + }} + rc, err = yaml.Marshal(theirNewConfig) + assert.NoError(t, err) + theirConfig.ReloadConfigString(string(rc)) + + r.Log("Spin until there is only 1 tunnel") + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + 
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // Make sure the correct tunnel won + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assert.Contains(t, c.Cert.Details.Groups, "new group") + + // We should only have a single tunnel now on both sides + assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestRehandshakingLoser(t *testing.T) { + // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel + // Should be the one with the new certificate + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + + // Put their info in our lighthouse and vice versa + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Stand up a tunnel between me and them") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Renew their certificate and spin until mine sees it") + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + theirConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(theirNextPEM), + "key": string(theirNextPrivKey), + } + rc, err := yaml.Marshal(theirConfig.Settings) + assert.NoError(t, err) + theirConfig.ReloadConfigString(string(rc)) + + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + + _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] + if theirNewGroup { + break + } + + time.Sleep(time.Second) + } + + // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly + rc, err = yaml.Marshal(myConfig.Settings) + assert.NoError(t, err) + var myNewConfig m + assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall["inbound"] = []m{{ + "proto": "any", + "port": "any", + "group": "their new group", + }} + rc, err = yaml.Marshal(myNewConfig) + 
assert.NoError(t, err) + myConfig.ReloadConfigString(string(rc)) + + r.Log("Spin until there is only 1 tunnel") + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // Make sure the correct tunnel won + theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + + // We should only have a single tunnel now on both sides + assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + myControl.Stop() + theirControl.Stop() +} + +func TestRaceRegression(t *testing.T) { + // This test forces stage 1, stage 2, stage 1 to be received by me from them + // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which + // caused a cross-linked hostinfo + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + //them rx stage:1 initiatorIndex=642843150 responderIndex=0 + //me rx stage:1 initiatorIndex=120607833 responderIndex=0 + //them rx stage:1 initiatorIndex=642843150 responderIndex=0 + //me rx stage:2 initiatorIndex=642843150 responderIndex=3701775874 + //me rx stage:1 initiatorIndex=120607833 responderIndex=0 + //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 + + t.Log("Start both handshakes") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + + t.Log("Get both stage 1") + myStage1ForThem := myControl.GetFromUDP(true) + theirStage1ForMe := theirControl.GetFromUDP(true) + + t.Log("Inject them in a special way") + theirControl.InjectUDPPacket(myStage1ForThem) + myControl.InjectUDPPacket(theirStage1ForMe) + theirControl.InjectUDPPacket(myStage1ForThem) + + //TODO: ensure stage 2 + t.Log("Get both stage 2") + myStage2ForThem := myControl.GetFromUDP(true) + theirStage2ForMe := theirControl.GetFromUDP(true) + + t.Log("Inject them in a special way again") + myControl.InjectUDPPacket(theirStage2ForMe) + myControl.InjectUDPPacket(theirStage1ForMe) + theirControl.InjectUDPPacket(myStage2ForThem) + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Flush the packets") + r.RouteForAllUntilTxTun(myControl) + r.RouteForAllUntilTxTun(theirControl) + r.RenderHostmaps("Starting hostmaps", myControl, 
theirControl) + + t.Log("Make sure the tunnel still works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +//TODO: test +// Race winner renews and handshakes +// Race loser renews and handshakes +// Does race winner repin the cert to old? //TODO: add a test with many lies diff --git a/e2e/helpers.go b/e2e/helpers.go new file mode 100644 index 000000000..13146ab71 --- /dev/null +++ b/e2e/helpers.go @@ -0,0 +1,118 @@ +package e2e + +import ( + "crypto/rand" + "io" + "net" + "time" + + "github.com/slackhq/nebula/cert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + InvertedGroups: make(map[string]struct{}), + }, + } + + if len(ips) > 0 { + nc.Details.Ips = ips + } + + if len(subnets) > 0 { + nc.Details.Subnets = subnets + } + + if len(groups) > 0 { + nc.Details.Groups = groups + } + + err = nc.Sign(cert.Curve_CURVE25519, priv) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. 
+// Expiry times are defaulted if you do not pass them in +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + issuer, err := ca.Sha256Sum() + if err != nil { + panic(err) + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + pub, rawPriv := x25519Keypair() + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: name, + Ips: []*net.IPNet{ip}, + Subnets: subnets, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + InvertedGroups: make(map[string]struct{}), + }, + } + + err = nc.Sign(ca.Details.Curve, key) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem +} + +func x25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index a378beafc..b05c84a22 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "crypto/rand" "fmt" "io" "net" @@ -12,9 +11,9 @@ import ( "testing" "time" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/imdario/mergo" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" @@ -22,15 +21,13 @@ import ( "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" "gopkg.in/yaml.v2" ) type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, net.IP, *net.UDPAddr) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { l := NewTestLogger() vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} @@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u IP: udpIp, Port: 4242, } - _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { @@ -77,6 +74,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), "level": l.Level.String(), }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, } if overrides != nil { @@ -101,113 +102,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet.IP, &udpAddr -} - -// newTestCaCert will generate a CA cert -func 
newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(priv) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, priv, pem -} - -// newTestCert will generate a signed certificate with the provided details. -// Expiry times are defaulted if you do not pass them in -func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - pub, rawPriv := x25519Keypair() - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(key) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey + return control, vpnIpNet, &udpAddr, c } type doneCb func() @@ -231,12 +126,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) - bPacket := r.RouteUntilTxTun(controlB, controlA) + bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) - aPacket := r.RouteUntilTxTun(controlA, controlB) + aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 948281ab5..120be6960 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,9 +5,11 @@ package router import ( "fmt" + "sort" "strings" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -61,10 +63,14 @@ func renderHostmap(c 
*nebula.Control) (string, []*edge) { r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() + hm.RLock() + defer hm.RUnlock() // Draw the vpn to index nodes r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName) - for vpnIp, hi := range hm.Hosts { + hosts := sortedHosts(hm.Hosts) + for _, vpnIp := range hosts { + hi := hm.Hosts[vpnIp] r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp) lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex())) @@ -91,11 +97,15 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { // Draw the local index to relay or remote index nodes r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName) - for idx, hi := range hm.Indexes { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) - remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") - globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) - _ = hi + indexes := sortedIndexes(hm.Indexes) + for _, idx := range indexes { + hi, ok := hm.Indexes[idx] + if ok { + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) + remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) + _ = hi + } } r += "\t\tend\n" @@ -107,3 +117,29 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { r += "\tend\n" return r, globalLines } + +func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { + keys := make([]iputil.VpnIp, 0, len(hosts)) + for key := range hosts { + keys = append(keys, key) + } + + sort.SliceStable(keys, func(i, j int) bool { + return keys[i] > keys[j] + }) + + return keys +} + +func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 { + keys := make([]uint32, 0, len(indexes)) + for key := range indexes { + keys = append(keys, key) + } + + sort.SliceStable(keys, func(i, j int) bool { + return keys[i] > keys[j] + }) + + return keys +} diff --git a/e2e/router/router.go b/e2e/router/router.go index aa56db8e2..730853a99 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "reflect" + "sort" "strconv" "strings" "sync" @@ -22,6 +23,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/maps" ) type R struct { @@ -150,6 +152,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { case <-ctx.Done(): return case <-clockSource.C: + r.renderHostmaps("clock tick") r.renderFlow() } } @@ -212,7 +215,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "#58;", 1) + sanAddr := strings.Replace(addr, ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -220,11 +223,16 @@ func (r *R) renderFlow() { ) } + if len(participantsVals) > 2 { + // Get the first and last participantVals for notes + participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]} + } + // Print packets h := &header.H{} for _, e := range r.flow { if e.packet == nil { - fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note) + //fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note) continue } @@ -244,9 +252,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -294,6 +302,28 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { }) } +func (r *R) renderHostmaps(title string) { + c := maps.Values(r.controls) + sort.SliceStable(c, func(i, j int) bool { + return c[i].GetVpnIp() > c[j].GetVpnIp() + }) + + s := renderHostmaps(c...) + if len(r.additionalGraphs) > 0 { + lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1] + if lastGraph.content == s { + // Ignore this rendering if it matches the last rendering added + // This is useful if you want to track rendering changes + return + } + } + + r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{ + title: title, + content: s, + }) +} + // InjectFlow can be used to record packet flow if the test is handling the routing on its own. // The packet is assumed to have been received func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) { @@ -332,6 +362,8 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool return nil } + r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow))) + if len(r.ignoreFlows) > 0 { var h header.H err := h.Parse(p.Data) @@ -726,8 +758,8 @@ func (r *R) formatUdpPacket(p *packet) string { data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
<br/>dest port: %v<br/>
data: \"%v\"\n", - strings.Replace(from, ":", "#58;", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(from, ":", "-", 1), + strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()), diff --git a/examples/config.yml b/examples/config.yml index f214bf746..9064c2300 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -11,7 +11,7 @@ pki: #blocklist: # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. - #disconnect_invalid: false + #disconnect_invalid: true # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. @@ -21,6 +21,19 @@ pki: static_host_map: "192.168.100.1": ["100.64.22.11:4242"] +# The static_map config stanza can be used to configure how the static_host_map behaves. +#static_map: + # cadence determines how frequently DNS is re-queried for updated IP addresses when a static_host_map entry contains + # a DNS name. + #cadence: 30s + + # network determines the type of IP addresses to ask the DNS server for. The default is "ip4" because nodes typically + # do not know their public IPv4 address. Connecting to the Lighthouse via IPv4 allows the Lighthouse to detect the + # public address. Other valid options are "ip6" and "ip" (returns both.) + #network: ip4 + + # lookup_timeout is the DNS query timeout. + #lookup_timeout: 250ms lighthouse: # am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes @@ -91,10 +104,23 @@ lighthouse: #- "1.1.1.1:4242" #- "1.2.3.4:0" # port will be replaced with the real listening port + # EXPERIMENTAL: This option may change or disappear in the future. + # This setting allows us to "guess" what the remote might be for a host + # while we wait for the lighthouse response. + #calculated_remotes: + # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add + # the calculated IP as an initial remote (while we wait for the response + # from the lighthouse). Both CIDRs must have the same mask size. + # For example, Nebula IP 10.0.10.123 will have a calculated remote of + # 192.168.1.123 + #10.0.10.0/24: + #- mask: 192.168.1.0/24 + # port: 4242 + # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "[::]" + # To listen on both any ipv4 and ipv6 use "::" host: 0.0.0.0 port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) @@ -129,9 +155,12 @@ punchy: # Default is false #respond: true - # delays a punch response for misbehaving NATs, default is 1 second, respond must be true to take effect + # delays a punch response for misbehaving NATs, default is 1 second. #delay: 1s + # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. + #respond_delay: 5s + # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes # IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously! 
#cipher: aes @@ -142,7 +171,8 @@ punchy: # and has been deprecated for "preferred_ranges" #preferred_ranges: ["172.16.0.0/24"] -# sshd can expose informational and administrative functions via ssh this is a +# sshd can expose informational and administrative functions via ssh. This can expose informational and administrative +# functions, and allows manual tweaking of various network settings when debugging or testing. #sshd: # Toggles the feature #enabled: true @@ -178,7 +208,7 @@ tun: disabled: false # Name of the device. If not set, a default will be chosen by the OS. # For macOS: if set, must be in the form `utun[0-9]+`. - # For FreeBSD: Required to be set, must be in the form `tun[0-9]+`. + # For NetBSD: Required to be set, must be in the form `tun[0-9]+` dev: nebula1 # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert drop_local_broadcast: false @@ -188,26 +218,36 @@ tun: tx_queue: 500 # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic mtu: 1300 + # Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here routes: #- mtu: 8800 # route: 10.0.0.0/16 + # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate - # `mtu` will default to tun mtu if this option is not specified - # `metric` will default to 0 if this option is not specified + # `mtu`: will default to tun mtu if this option is not specified + # `metric`: will default to 0 if this option is not specified + # `install`: will default to true, controls whether this route is installed in the systems routing table. unsafe_routes: #- route: 172.16.1.0/24 # via: 192.168.100.99 # mtu: 1300 # metric: 100 + # install: true + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of + # in nebula configuration files. Default false, not reloadable. + #use_system_route_table: false # TODO # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info + # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. + #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug logging is also CPU intensive and will decrease performance overall. + # Only enable debug logging while actively investigating an issue. level: info # json or text formats currently available. Default is text format: text @@ -252,6 +292,10 @@ logging: # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 + + # query_buffer is the size of the buffer channel for querying lighthouses + #query_buffer: 64 + # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 @@ -259,6 +303,22 @@ logging: # Nebula security group configuration firewall: + # Action to take when a packet is not allowed by the firewall rules. + # Can be one of: + # `drop` (default): silently drop the packet. + # `reject`: send a reject reply. + # - For TCP, this will be a RST "Connection Reset" packet. 
+ # - For other protocols, this will be an ICMP port unreachable packet. + outbound_action: drop + inbound_action: drop + + # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false. + # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an + # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless + # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr` + # if the intention is to allow traffic to flow to an unsafe route. + #default_local_cidr_any: false + conntrack: tcp_timeout: 12m udp_timeout: 3m @@ -266,14 +326,17 @@ firewall: # The firewall is default deny. There is no way to write a deny rule. # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR - # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) + # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a CIDR, `0.0.0.0/0` is any. + # cidr: a remote CIDR, `0.0.0.0/0` is any. + # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. + # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate + # if `default_local_cidr_any` is false, otherwise its `any`. 
# ca_name: An issuing CA name # ca_sha: An issuing CA shasum @@ -295,3 +358,10 @@ firewall: groups: - laptop - home + + # Expose a subnet (unsafe route) to hosts with the group remote_client + # This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate + - port: 8080 + proto: tcp + group: remote_client + local_cidr: 192.168.100.1/24 diff --git a/examples/go_service/main.go b/examples/go_service/main.go new file mode 100644 index 000000000..f46273acf --- /dev/null +++ b/examples/go_service/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "bufio" + "fmt" + "log" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("%+v", err) + } +} + +func run() error { + configStr := ` +tun: + user: true + +static_host_map: + '192.168.100.1': ['localhost:4242'] + +listen: + host: 0.0.0.0 + port: 4241 + +lighthouse: + am_lighthouse: false + interval: 60 + hosts: + - '192.168.100.1' + +firewall: + outbound: + # Allow all outbound traffic from this node + - port: any + proto: any + host: any + + inbound: + # Allow icmp between any nebula hosts + - port: any + proto: icmp + host: any + - port: any + proto: any + host: any + +pki: + ca: /home/rice/Developer/nebula-config/ca.crt + cert: /home/rice/Developer/nebula-config/app.crt + key: /home/rice/Developer/nebula-config/app.key +` + var config config.C + if err := config.LoadString(configStr); err != nil { + return err + } + service, err := service.New(&config) + if err != nil { + return err + } + + ln, err := service.Listen("tcp", ":1234") + if err != nil { + return err + } + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("accept error: %s", err) + break + } + defer conn.Close() + + log.Printf("got connection") + + conn.Write([]byte("hello world\n")) + + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + message := scanner.Text() + fmt.Fprintf(conn, "echo: %q\n", message) + log.Printf("got message %q", message) + } + + if err := scanner.Err(); err != nil { + log.Printf("scanner error: %s", err) + break + } + } + + service.Close() + if err := service.Wait(); err != nil { + return err + } + return nil +} diff --git a/examples/service_scripts/nebula.plist b/examples/service_scripts/nebula.plist new file mode 100644 index 000000000..c423cfc72 --- /dev/null +++ b/examples/service_scripts/nebula.plist @@ -0,0 +1,34 @@ + + + + + KeepAlive + + Label + net.defined.nebula + WorkingDirectory + /Users/{username}/.local/bin/nebula + LimitLoadToSessionType + + Aqua + Background + LoginWindow + StandardIO + System + + ProgramArguments + + ./nebula + -config + ./config.yml + + RunAtLoad + + StandardErrorPath + ./nebula.log + StandardOutPath + ./nebula.log + UserName + root + + \ No newline at end of file diff --git a/examples/service_scripts/nebula.service b/examples/service_scripts/nebula.service index fd7a06710..ab5218f8d 100644 --- a/examples/service_scripts/nebula.service +++ b/examples/service_scripts/nebula.service @@ -5,6 +5,8 @@ After=basic.target network.target network-online.target Before=sshd.service [Service] +Type=notify +NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml diff --git a/firewall.go b/firewall.go index 9fd75fc37..3e760feb3 100644 --- a/firewall.go +++ b/firewall.go @@ -2,10 +2,10 @@ package nebula import ( "crypto/sha256" - "encoding/binary" "encoding/hex" "errors" "fmt" + "hash/fnv" "net" "reflect" "strconv" 
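The hunks below change FirewallInterface.AddRule to take an additional localIp *net.IPNet so a rule can be scoped to a local (destination) CIDR, matching the new `local_cidr` config field. A minimal sketch of a hypothetical caller using the new signature; the helper name, group, and CIDR values here are illustrative only and are not part of this diff:

// addUnsafeRouteRule is a hypothetical example (not part of this change) showing
// the updated AddRule signature with its new localIp argument.
func addUnsafeRouteRule(fw *Firewall) error {
	_, remoteNet, err := net.ParseCIDR("10.0.10.0/24") // matched against the remote `cidr` field (made-up value)
	if err != nil {
		return err
	}
	_, localNet, err := net.ParseCIDR("192.168.100.0/24") // matched against the new `local_cidr` field (made-up value)
	if err != nil {
		return err
	}
	// Allow inbound TCP 8080 from hosts in group "remote_client", but only when the
	// packet is destined for the local 192.168.100.0/24 unsafe route.
	return fw.AddRule(true, firewall.ProtoTCP, 8080, 8080, []string{"remote_client"}, "", remoteNet, localNet, "", "")
}

When localIp is nil, the rule falls back to the default behavior governed by firewall.default_local_cidr_any and the presence of subnets in the certificate, as described in the examples/config.yml changes above.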
@@ -21,17 +21,12 @@ import ( "github.com/slackhq/nebula/firewall" ) -const tcpACK = 0x10 -const tcpFIN = 0x01 - type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error } type conn struct { Expires time.Time // Time when this conntrack entry will expire - Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set - Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack // record why the original connection passed the firewall, so we can re-validate // after ruleset changes. Note, rulesVersion is a uint16 so that these two @@ -47,6 +42,9 @@ type Firewall struct { InRules *FirewallTable OutRules *FirewallTable + InSendReject bool + OutSendReject bool + //TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better // https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt TCPTimeout time.Duration //linux: 5 days max @@ -54,15 +52,16 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4 + localIps *cidr.Tree4[struct{}] + assignedCIDR *net.IPNet + hasSubnets bool rules string rulesVersion uint16 - trackTCPRTT bool - metricTCPRTT metrics.Histogram - incomingMetrics firewallMetrics - outgoingMetrics firewallMetrics + defaultLocalCIDRAny bool + incomingMetrics firewallMetrics + outgoingMetrics firewallMetrics l *logrus.Logger } @@ -80,6 +79,8 @@ type FirewallConntrack struct { TimerWheel *TimerWheel[firewall.Packet] } +// FirewallTable is the entry point for a rule, the evaluation order is: +// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { TCP firewallPort UDP firewallPort @@ -104,16 +105,26 @@ type FirewallCA struct { type FirewallRule struct { // Any makes Hosts, Groups, and CIDR irrelevant - Any bool - Hosts map[string]struct{} - Groups [][]string - CIDR *cidr.Tree4 + Any *firewallLocalCIDR + Hosts map[string]*firewallLocalCIDR + Groups []*firewallGroups + CIDR *cidr.Tree4[*firewallLocalCIDR] +} + +type firewallGroups struct { + Groups []string + LocalCIDR *firewallLocalCIDR } // Even though ports are uint16, int32 maps are faster for lookup // Plus we can use `-1` for fragment rules type firewallPort map[int32]*FirewallCA +type firewallLocalCIDR struct { + Any bool + LocalCIDR *cidr.Tree4[struct{}] +} + // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. 
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { //TODO: error on 0 duration @@ -133,9 +144,16 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4() + localIps := cidr.NewTree4[struct{}]() + var assignedCIDR *net.IPNet for _, ip := range c.Details.Ips { - localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} + localIps.AddCIDR(ipNet, struct{}{}) + + if assignedCIDR == nil { + // Only grabbing the first one in the cert since any more than that currently has undefined behavior + assignedCIDR = ipNet + } } for _, n := range c.Details.Subnets { @@ -153,9 +171,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, localIps: localIps, + assignedCIDR: assignedCIDR, + hasSubnets: len(c.Details.Subnets) > 0, l: l, - metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), incomingMetrics: firewallMetrics{ droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), @@ -179,6 +198,31 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf //TODO: max_connections ) + //TODO: Flip to false after v1.9 release + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + + inboundAction := c.GetString("firewall.inbound_action", "drop") + switch inboundAction { + case "reject": + fw.InSendReject = true + case "drop": + fw.InSendReject = false + default: + l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") + fw.InSendReject = false + } + + outboundAction := c.GetString("firewall.outbound_action", "drop") + switch outboundAction { + case "reject": + fw.OutSendReject = true + case "drop": + fw.OutSendReject = false + default: + l.WithField("action", inboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + fw.OutSendReject = false + } + err := AddFirewallRulesFromConfig(l, false, c, fw) if err != nil { return nil, err @@ -193,18 +237,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" if ip != nil { sIp = ip.String() } + lIp := "" + if localIp != nil { + lIp = localIp.String() + } // We need this rule string because we generate a hash. Removing this will break firewall reload. 
ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha, + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha, ) f.rules += ruleString + "\n" @@ -212,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}). Info("Firewall rule added") var ( @@ -239,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } - return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha) + return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha) } // GetRuleHash returns a hash representation of all inbound and outbound rules @@ -248,6 +296,18 @@ func (f *Firewall) GetRuleHash() string { return hex.EncodeToString(sum[:]) } +// GetRuleHashFNV returns a uint32 FNV-1 hash representation the rules, for use as a metric value +func (f *Firewall) GetRuleHashFNV() uint32 { + h := fnv.New32a() + h.Write([]byte(f.rules)) + return h.Sum32() +} + +// GetRuleHashes returns both the sha256 and FNV-1 hashes, suitable for logging +func (f *Firewall) GetRuleHashes() string { + return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) +} + func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { @@ -277,8 +337,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) } - if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" { - return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i) + if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { + return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) } if len(r.Groups) > 0 { @@ -330,7 +390,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } } - err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha) + var localCidr *net.IPNet + if r.LocalCidr != "" { + _, localCidr, err = net.ParseCIDR(r.LocalCidr) + if err != nil { + return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) + } + } + + err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha) if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } @@ -345,15 +413,16 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. 
It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(packet, fp, incoming, h, caPool, localCache) { + if f.inConns(fp, h, caPool, localCache) { return nil } // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - if remoteCidr.Contains(fp.RemoteIP) == nil { + ok, _ := remoteCidr.Contains(fp.RemoteIP) + if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP } @@ -366,7 +435,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // Make sure we are supposed to be handling this local ip address - if f.localIps.Contains(fp.LocalIP) == nil { + ok, _ := f.localIps.Contains(fp.LocalIP) + if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP } @@ -383,7 +453,7 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // We always want to conntrack since it is a faster operation - f.addConn(packet, fp, incoming) + f.addConn(fp, incoming) return nil } @@ -409,9 +479,10 @@ func (f *Firewall) EmitStats() { conntrack.Unlock() metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) + metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -471,11 +542,6 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h * switch fp.Protocol { case firewall.ProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) - if incoming { - f.checkTCPRTT(c, packet) - } else { - setTCPRTTTracking(c, packet) - } case firewall.ProtoUDP: c.Expires = time.Now().Add(f.UDPTimeout) default: @@ -491,16 +557,13 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h * return true } -func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { var timeout time.Duration c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: timeout = f.TCPTimeout - if !incoming { - setTCPRTTTracking(c, packet) - } case firewall.ProtoUDP: timeout = f.UDPTimeout default: @@ -570,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -583,7 +646,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups 
[]string, } } - if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil { + if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil { return err } } @@ -614,12 +677,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ - Hosts: make(map[string]struct{}), - Groups: make([][]string, 0), - CIDR: cidr.NewTree4(), + Hosts: make(map[string]*firewallLocalCIDR), + Groups: make([]*firewallGroups, 0), + CIDR: cidr.NewTree4[*firewallLocalCIDR](), } } @@ -628,14 +691,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam fc.Any = fr() } - return fc.Any.addRule(groups, host, ip) + return fc.Any.addRule(f, groups, host, ip, localIp) } if caSha != "" { if _, ok := fc.CAShas[caSha]; !ok { fc.CAShas[caSha] = fr() } - err := fc.CAShas[caSha].addRule(groups, host, ip) + err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp) if err != nil { return err } @@ -645,7 +708,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam if _, ok := fc.CANames[caName]; !ok { fc.CANames[caName] = fr() } - err := fc.CANames[caName].addRule(groups, host, ip) + err := fc.CANames[caName].addRule(f, groups, host, ip, localIp) if err != nil { return err } @@ -677,29 +740,56 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { - if fr.Any { - return nil +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { + flc := func() *firewallLocalCIDR { + return &firewallLocalCIDR{ + LocalCIDR: cidr.NewTree4[struct{}](), + } } if fr.isAny(groups, host, ip) { - fr.Any = true - // If it's any we need to wipe out any pre-existing rules to save on memory - fr.Groups = make([][]string, 0) - fr.Hosts = make(map[string]struct{}) - fr.CIDR = cidr.NewTree4() - } else { - if len(groups) > 0 { - fr.Groups = append(fr.Groups, groups) + if fr.Any == nil { + fr.Any = flc() } - if host != "" { - fr.Hosts[host] = struct{}{} + return fr.Any.addRule(f, localCIDR) + } + + if len(groups) > 0 { + nlc := flc() + err := nlc.addRule(f, localCIDR) + if err != nil { + return err + } + + fr.Groups = append(fr.Groups, &firewallGroups{ + Groups: groups, + LocalCIDR: nlc, + }) + } + + if host != "" { + nlc := fr.Hosts[host] + if nlc == nil { + nlc = flc() + } + err := nlc.addRule(f, localCIDR) + if err != nil { + return err } + fr.Hosts[host] = nlc + } - if ip != nil { - fr.CIDR.AddCIDR(ip, struct{}{}) + if ip != nil { + _, nlc := fr.CIDR.GetCIDR(ip) + if nlc == nil { + nlc = flc() + } + err := nlc.addRule(f, localCIDR) + if err != nil { + return err } + fr.CIDR.AddCIDR(ip, nlc) } return nil @@ -733,7 +823,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } // Shortcut path for if groups, hosts, or cidr contained an `any` - if fr.Any { + if fr.Any.match(p, c) { return true } @@ -741,7 +831,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool for _, sg := range fr.Groups { found := false - for _, g := range sg { 
+ for _, g := range sg.Groups { if _, ok := c.Details.InvertedGroups[g]; !ok { found = false break @@ -750,35 +840,64 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found = true } - if found { + if found && sg.LocalCIDR.match(p, c) { return true } } if fr.Hosts != nil { - if _, ok := fr.Hosts[c.Details.Name]; ok { - return true + if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc.match(p, c) { + return true + } } } - if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil { + return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { + return flc.match(p, c) + }) +} + +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { + if localIp == nil { + if !f.hasSubnets || f.defaultLocalCIDRAny { + flc.Any = true + return nil + } + + localIp = f.assignedCIDR + } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { + flc.Any = true + } + + flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + return nil +} + +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { + if flc == nil { + return false + } + + if flc.Any { return true } - // No host, group, or cidr matched, bye bye - return false + ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + return ok } type rule struct { - Port string - Code string - Proto string - Host string - Group string - Groups []string - Cidr string - CAName string - CASha string + Port string + Code string + Proto string + Host string + Group string + Groups []string + Cidr string + LocalCidr string + CAName string + CASha string } func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { @@ -802,6 +921,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er r.Proto = toString("proto", m) r.Host = toString("host", m) r.Cidr = toString("cidr", m) + r.LocalCidr = toString("local_cidr", m) r.CAName = toString("ca_name", m) r.CASha = toString("ca_sha", m) @@ -880,42 +1000,3 @@ func parsePort(s string) (startPort, endPort int32, err error) { return } - -// TODO: write tests for these -func setTCPRTTTracking(c *conn, p []byte) { - if c.Seq != 0 { - return - } - - ihl := int(p[0]&0x0f) << 2 - - // Don't track FIN packets - if p[ihl+13]&tcpFIN != 0 { - return - } - - c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8]) - c.Sent = time.Now() -} - -func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { - if c.Seq == 0 { - return false - } - - ihl := int(p[0]&0x0f) << 2 - if p[ihl+13]&tcpACK == 0 { - return false - } - - // Deal with wrap around, signed int cuts the ack window in half - // 0 is a bad ack, no data acknowledged - // positive number is a bad ack, ack is over half the window away - if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 { - return false - } - - f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds()) - c.Seq = 0 - return true -} diff --git a/firewall_test.go b/firewall_test.go index 4f24ac03f..b5beff61e 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -2,14 +2,12 @@ package nebula import ( "bytes" - "encoding/binary" "errors" "math" "net" "testing" "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -69,67 +67,57 @@ func TestFirewall_AddRule(t *testing.T) { _, ti, _ := net.ParseCIDR("1.2.3.4/32") - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) // An empty rule is 
any - assert.True(t, fw.InRules.TCP[1].Any.Any) + assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) - assert.False(t, fw.InRules.UDP[1].Any.Any) - assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + assert.Nil(t, fw.InRules.UDP[1].Any.Any) + assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) - assert.False(t, fw.InRules.ICMP[1].Any.Any) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", "")) - assert.False(t, fw.OutRules.AnyProto[1].Any.Any) - assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) - assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) + ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) - assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) + ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) - assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - // Set any and clear fields fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) - assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) - assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) - - // run twice just to make sure - //TODO: these ANY rules should clear the CA firewall portion - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any.Any) - assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) - assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + assert.Contains(t, 
fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any.Any) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any.Any) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -138,12 +126,12 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - 10, - 90, - firewall.ProtoUDP, - false, + LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, } ipNet := net.IPNet{ @@ -169,78 +157,88 @@ func TestFirewall_Drop(t *testing.T) { h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteIP p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop 
on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { + f := &Firewall{} ft := FirewallTable{ TCP: firewallPort{}, } _, n, _ := net.ParseCIDR("172.1.1.1/32") - _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") + goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { + // This benchmark is showing us the cost of failing to match the protocol c := &cert.NebulaCertificate{} for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) } }) - b.Run("fail on port", func(b *testing.B) { + b.Run("pass proto, fail on port", func(b *testing.B) { + // This benchmark is showing us the cost of matching a specific protocol but failing to match the 
port c := &cert.NebulaCertificate{} for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) } }) - b.Run("fail all group, name, and cidr", func(b *testing.B) { + b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { + c := &cert.NebulaCertificate{} + ip, _, _ := net.ParseCIDR("9.254.254.254/32") + lip := iputil.Ip2VpnIp(ip) + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + } + }) + + b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { _, ip, _ := net.ParseCIDR("9.254.254.254/32") c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ @@ -250,51 +248,49 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) - b.Run("pass on group", func(b *testing.B) { + b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { + _, ip, _ := net.ParseCIDR("9.254.254.254/32") c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, + InvertedGroups: map[string]struct{}{"nope": {}}, Name: "nope", + Ips: []*net.IPNet{ip}, }, } for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) } }) - b.Run("pass on name", func(b *testing.B) { + b.Run("pass on group on any local cidr", func(b *testing.B) { c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "good-host", + InvertedGroups: map[string]struct{}{"good-group": {}}, + Name: "nope", }, } for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) - b.Run("pass on ip", func(b *testing.B) { - ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + b.Run("pass on group on specific local cidr", func(b *testing.B) { c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "good-host", + InvertedGroups: map[string]struct{}{"good-group": {}}, + Name: "nope", }, } for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) } }) - _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") - - b.Run("pass on ip with any port", func(b *testing.B) { - ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + b.Run("pass on name", func(b *testing.B) { c := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ InvertedGroups: map[string]struct{}{"nope": {}}, @@ -302,9 +298,63 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, 
c, cp) + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) + // + //b.Run("pass on ip", func(b *testing.B) { + // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + // c := &cert.NebulaCertificate{ + // Details: cert.NebulaCertificateDetails{ + // InvertedGroups: map[string]struct{}{"nope": {}}, + // Name: "good-host", + // }, + // } + // for n := 0; n < b.N; n++ { + // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) + // } + //}) + // + //b.Run("pass on local ip", func(b *testing.B) { + // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + // c := &cert.NebulaCertificate{ + // Details: cert.NebulaCertificateDetails{ + // InvertedGroups: map[string]struct{}{"nope": {}}, + // Name: "good-host", + // }, + // } + // for n := 0; n < b.N; n++ { + // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) + // } + //}) + // + //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") + // + //b.Run("pass on ip with any port", func(b *testing.B) { + // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + // c := &cert.NebulaCertificate{ + // Details: cert.NebulaCertificateDetails{ + // InvertedGroups: map[string]struct{}{"nope": {}}, + // Name: "good-host", + // }, + // } + // for n := 0; n < b.N; n++ { + // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) + // } + //}) + // + //b.Run("pass on local ip with any port", func(b *testing.B) { + // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + // c := &cert.NebulaCertificate{ + // Details: cert.NebulaCertificateDetails{ + // InvertedGroups: map[string]struct{}{"nope": {}}, + // Name: "good-host", + // }, + // } + // for n := 0; n < b.N; n++ { + // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) + // } + //}) } func TestFirewall_Drop2(t *testing.T) { @@ -313,12 +363,12 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - 10, - 90, - firewall.ProtoUDP, - false, + LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, } ipNet := net.IPNet{ @@ -356,14 +406,14 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule) + assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -372,12 +422,12 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - 1, - 1, - firewall.ProtoUDP, - false, + LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalPort: 1, + RemotePort: 1, + Protocol: 
firewall.ProtoUDP, + Fragment: false, } ipNet := net.IPNet{ @@ -438,18 +488,18 @@ func TestFirewall_Drop3(t *testing.T) { h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -458,12 +508,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - 10, - 90, - firewall.ProtoUDP, - false, + LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, } ipNet := net.IPNet{ @@ -489,34 +539,34 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) 
} func BenchmarkLookup(b *testing.B) { @@ -653,7 +703,7 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") + assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) @@ -677,6 +727,12 @@ func TestNewFirewallFromConfig(t *testing.T) { _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + // Test local_cidr parse error + conf = config.NewC(l) + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} + _, err = NewFirewallFromConfig(l, c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} @@ -691,63 +747,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: 
true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + + // Test adding rule with cidr + cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + + // Test adding rule with local_cidr + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, 
addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -757,97 +828,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } -func TestTCPRTTTracking(t *testing.T) { - b := make([]byte, 200) - - // Max ip IHL (60 bytes) and tcp IHL (60 bytes) - b[0] = 15 - b[60+12] = 15 << 4 - f := Firewall{ - metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)), - } - - // Set SEQ to 1 - binary.BigEndian.PutUint32(b[60+4:60+8], 1) - - c := &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, uint32(1), c.Seq) - - // Bad ack - no ack flag - binary.BigEndian.PutUint32(b[60+8:60+12], 80) - assert.False(t, f.checkTCPRTT(c, b)) - - // Bad ack, number is too low - binary.BigEndian.PutUint32(b[60+8:60+12], 0) - b[60+13] = uint8(0x10) - assert.False(t, f.checkTCPRTT(c, b)) - - // Good ack - binary.BigEndian.PutUint32(b[60+8:60+12], 80) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to 1 - binary.BigEndian.PutUint32(b[60+4:60+8], 1) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, uint32(1), c.Seq) - - // Good acks - binary.BigEndian.PutUint32(b[60+8:60+12], 81) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 - 20 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0)-20, c.Seq) - - // Good acks - binary.BigEndian.PutUint32(b[60+8:60+12], 81) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 / 2 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Below - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Halfway below - binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0)) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Halfway above is ok - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0), c.Seq) - - // Halfway + 1 above - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0), c.Seq) - - // Halfway above - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) -} - func TestFirewall_convertRule(t *testing.T) { l := 
test.NewLogger() ob := &bytes.Buffer{} @@ -892,6 +872,7 @@ type addRuleCall struct { groups []string host string ip *net.IPNet + localIp *net.IPNet caName string caSha string } @@ -901,7 +882,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, @@ -910,6 +891,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end groups: groups, host: host, ip: ip, + localIp: localIp, caName: caName, caSha: caSha, } diff --git a/go.mod b/go.mod index 8e8a354b7..4676616ff 100644 --- a/go.mod +++ b/go.mod @@ -1,47 +1,54 @@ module github.com/slackhq/nebula -go 1.19 +go 1.22.0 + +toolchain go1.22.2 require ( + dario.cat/mergo v1.0.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 - github.com/flynn/noise v1.0.0 + github.com/flynn/noise v1.1.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 - github.com/imdario/mergo v0.3.13 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.50 + github.com/miekg/dns v1.1.59 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.14.0 + github.com/prometheus/client_golang v1.18.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.0 + github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.8.1 - github.com/vishvananda/netlink v1.1.0 - golang.org/x/crypto v0.3.0 - golang.org/x/net v0.2.0 - golang.org/x/sys v0.2.0 - golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 + github.com/stretchr/testify v1.9.0 + github.com/vishvananda/netlink v1.2.1-beta.2 + golang.org/x/crypto v0.22.0 + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 + golang.org/x/net v0.24.0 + golang.org/x/sync v0.7.0 + golang.org/x/sys v0.19.0 + golang.org/x/term v0.19.0 + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 + golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.28.1 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v2 v2.4.0 + gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) require ( github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/golang/protobuf v1.5.2 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/google/btree v1.1.2 // indirect + github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.3.0 // indirect - github.com/prometheus/common v0.37.0 // indirect - github.com/prometheus/procfs v0.8.0 // indirect - github.com/vishvananda/netns v0.0.1 // indirect - golang.org/x/mod v0.7.0 // indirect - golang.org/x/term v0.2.0 // indirect - 
golang.org/x/tools v0.3.0 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.45.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/mod v0.16.0 // indirect + golang.org/x/time v0.5.0 // indirect + golang.org/x/tools v0.19.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3c5eaa75c..773a56137 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,6 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= -cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= -cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= -cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= -cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 
-github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -46,108 +14,52 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= -github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= 
+github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= +github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf 
v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod 
h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk= -github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60= @@ -158,21 +70,21 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= -github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= -github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= -github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= +github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= +github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= +github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= +github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.2/go.mod 
h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg= @@ -186,364 +98,147 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= -github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= +github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= +github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= -github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE= -github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA= +github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= +github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= -github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= 
+github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= -github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= -github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= -github.com/vishvananda/netns v0.0.1 h1:JDkWS7Axy5ziNM3svylLhpSgqjPDb+BgVUbXoDo+iPw= -github.com/vishvananda/netns v0.0.1/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod 
h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= +github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= -golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/lint 
v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= -golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod 
h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= 
+golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q= +golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM= -golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= -golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod 
h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= 
-google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -552,16 +247,7 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= +gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g= +gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 1cad0db0c..000000000 --- a/handshake.go +++ /dev/null @@ -1,31 +0,0 @@ -package nebula - -import ( - "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" -) - -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) { - // First remote allow list check before we know the vpnIp - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { - f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") - return - } - } - - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(f, addr, via, packet, h) - case 2: - newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteHostInfo(newHostinfo) - } - } - } - -} diff --git a/handshake_ix.go b/handshake_ix.go index 11a16a6dc..8727b16f1 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -4,6 +4,7 @@ import ( "time" "github.com/flynn/noise" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" @@ -13,27 +14,22 @@ import ( // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { - // This queries the lighthouse if we don't know a remote for the host - // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send - // more quickly, effect is a quicker handshake. - if hostinfo.remote == nil { - f.lightHouse.QueryServer(vpnIp, f) - } - - err := f.handshakeManager.AddIndexHostInfo(hostinfo) +func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { + err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). 
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") - return + return false } - ci := hostinfo.ConnectionState + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + hh.hostinfo.ConnectionState = ci hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hostinfo.localIndexId, + InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: ci.certState.rawCertificateNoKey, + Cert: certState.RawCertificateNoKey, } hsBytes := []byte{} @@ -44,9 +40,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hsBytes, err = hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return + return false } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) @@ -54,22 +50,23 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return + return false } // We are sending handshake packet 1, so we don't expect to receive // handshake packet 1 from the responder ci.window.Update(f.l, 1) - hostinfo.HandshakePacket[0] = msg - hostinfo.HandshakeReady = true - hostinfo.handshakeStart = time.Now() + hh.hostinfo.HandshakePacket[0] = msg + hh.ready = true + return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) { - ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) +func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -91,11 +88,16 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). - Info("Invalid certificate from host") + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid certificate from host") return } vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) @@ -143,9 +145,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b }, } - hostinfo.Lock() - defer hostinfo.Unlock() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). 
@@ -155,7 +154,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hs.Details.Cert = certState.RawCertificateNoKey // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) @@ -207,25 +206,16 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert) - // Only overwrite existing record if we should win the handshake race - overwrite := vpnIp > f.myVpnIp - existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) + existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { switch err { case ErrAlreadySeen: - // Update remote if preferred (Note we have to switch to locking - // the existing hostinfo, and then switch back so the defer Unlock - // higher in this function still works) - hostinfo.Unlock() - existing.Lock() // Update remote if preferred if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } - existing.Unlock() - hostinfo.Lock() msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) @@ -242,14 +232,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b } return } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return @@ -280,16 +269,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). Error("Failed to add HostInfo due to localIndex collision") return - case ErrExistingHandshake: - // We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Prevented a pending handshake race") - return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here @@ -323,41 +302,41 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). 
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + hostinfo.ConnectionState.messageCounter.Store(2) + hostinfo.remotes.ResetBlockedRemotes() return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool { - if hostinfo == nil { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { + if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true } - hostinfo.Lock() - defer hostinfo.Unlock() + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo if addr != nil { if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") @@ -366,22 +345,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * } ci := hostinfo.ConnectionState - if ci.ready { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Info("Handshake is already complete") - - // Update remote if preferred - if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets - return false - } - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). @@ -412,11 +375,16 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). 
- Error("Invalid certificate from host") + e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Error("Invalid certificate from host") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true @@ -435,34 +403,30 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * Info("Incorrect host responded to handshake") // Release our old handshake from pending, it should not continue - f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - //TODO: this adds it to the timer wheel in a way that aggressively retries - newHostInfo := f.getOrHandshake(hostinfo.vpnIp) - newHostInfo.Lock() - - // Block the current used address - newHostInfo.remotes = hostinfo.remotes - newHostInfo.remotes.BlockRemote(addr) + f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { + //TODO: this doesnt know if its being added or is being used for caching a packet + // Block the current used address + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(addr) - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + // Get the correct remote list for the host we did handshake with + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). - Info("Blocked addresses for handshakes") + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). + Info("Blocked addresses for handshakes") - // Swap the packet store to benefit the original intended recipient - hostinfo.ConnectionState.queueLock.Lock() - newHostInfo.packetStore = hostinfo.packetStore - hostinfo.packetStore = []*cachedPacket{} - hostinfo.ConnectionState.queueLock.Unlock() + // Swap the packet store to benefit the original intended recipient + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp - f.sendCloseTunnel(hostinfo) - newHostInfo.Unlock() + // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnIp = vpnIp + f.sendCloseTunnel(hostinfo) + }) return true } @@ -470,7 +434,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * // Mark packet 2 as seen so it doesn't show up as missed ci.window.Update(f.l, 2) - duration := time.Since(hostinfo.handshakeStart).Nanoseconds() + duration := time.Since(hh.startTime).Nanoseconds() f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -478,7 +442,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). 
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("durationNs", duration). - WithField("sentCachedPackets", len(hostinfo.packetStore)). + WithField("sentCachedPackets", len(hh.packetStore)). Info("Handshake message received") hostinfo.remoteIndexId = hs.Details.ResponderIndex @@ -493,17 +457,32 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * if addr != nil { hostinfo.SetRemote(addr) } else { - via2 := via.(*ViaSender) - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) } // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp - //TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok? f.handshakeManager.Complete(hostinfo, f) - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + + hostinfo.ConnectionState.messageCounter.Store(2) + + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + } + + if len(hh.packetStore) > 0 { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.ResetBlockedRemotes() f.metricHandshakes.Update(duration) return false diff --git a/handshake_manager.go b/handshake_manager.go index 432584102..640227a7e 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "net" + "sync" "time" "github.com/rcrowley/go-metrics" @@ -42,24 +43,68 @@ type HandshakeConfig struct { } type HandshakeManager struct { - pendingHostMap *HostMap + // Mutex for interacting with the vpnIps and indexes maps + sync.RWMutex + + vpnIps map[iputil.VpnIp]*HandshakeHostInfo + indexes map[uint32]*HandshakeHostInfo + mainHostMap *HostMap lightHouse *LightHouse - outside *udp.Conn + outside udp.Conn config HandshakeConfig OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter + f *Interface l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp trigger chan iputil.VpnIp } -func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager { +type HandshakeHostInfo struct { + sync.Mutex + + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + + hostinfo *HostInfo +} + +func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { + if len(hh.packetStore) < 100 { + tempPacket := make([]byte, len(packet)) + copy(tempPacket, packet) + + hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) + if l.Level >= 
logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", true). + Debugf("Packet store") + } + + } else { + m.dropped.Inc(1) + + if l.Level >= logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", false). + Debugf("Packet store") + } + } +} + +func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), + vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -73,7 +118,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -82,66 +127,80 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { case <-ctx.Done(): return case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, f, true) + c.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now, f) + c.NextOutboundHandshakeTimerTick(now) } } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { +func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { + // First remote allow list check before we know the vpnIp + if addr != nil { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + + switch h.Subtype { + case header.HandshakeIXPSK0: + switch h.MessageCounter { + case 1: + ixHandshakeStage1(hm.f, addr, via, packet, h) + + case 2: + newHostinfo := hm.queryIndex(h.RemoteIndex) + tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) + if tearDown && newHostinfo != nil { + hm.DeleteHostInfo(newHostinfo.hostinfo) + } + } + } +} + +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, f, false) + c.handleOutbound(vpnIp, false) } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { - hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { - return - } - hostinfo.Lock() - defer hostinfo.Unlock() - - // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. 
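The new HandshakeHostInfo.cachePacket above copies each packet and keeps at most 100 of them per pending handshake, counting anything beyond the cap as dropped. A small self-contained sketch of that bounding behaviour follows; pendingBuffer and its fields are invented for illustration and are not Nebula types:

package main

import "fmt"

// pendingBuffer is an illustrative stand-in for the per-handshake packet
// store: packets queued before the tunnel is ready are kept up to a fixed
// limit, and everything beyond the limit is counted as dropped.
type pendingBuffer struct {
	packets [][]byte
	dropped int
	limit   int
}

func (b *pendingBuffer) cache(p []byte) bool {
	if len(b.packets) >= b.limit {
		b.dropped++
		return false
	}
	// Copy the packet; the caller's buffer will be reused for the next read.
	c := make([]byte, len(p))
	copy(c, p)
	b.packets = append(b.packets, c)
	return true
}

func main() {
	b := &pendingBuffer{limit: 100}
	buf := make([]byte, 1200)
	for i := 0; i < 105; i++ {
		b.cache(buf)
	}
	fmt.Printf("stored=%d dropped=%d\n", len(b.packets), b.dropped)
	// stored=100 dropped=5
}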
- if hostinfo.HandshakeComplete { - // Ensure we don't exist in the pending hostmap anymore since we have completed - c.pendingHostMap.DeleteHostInfo(hostinfo) - return - } - - // Check if we have a handshake packet to transmit yet - if !hostinfo.HandshakeReady { - // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly - // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) +func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { + hh := hm.queryVpnIp(vpnIp) + if hh == nil { return } + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo // If we are out of time, clean up - if hostinfo.HandshakeCounter >= c.config.retries { - hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). + if hh.counter >= hm.config.retries { + hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). + WithField("initiatorIndex", hh.hostinfo.localIndexId). + WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). + WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). Info("Handshake timed out") - c.metricTimedOut.Inc(1) - c.pendingHostMap.DeleteHostInfo(hostinfo) + hm.metricTimedOut.Inc(1) + hm.DeleteHostInfo(hostinfo) return } - // We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific - // optimization for a fast lighthouse reply - //TODO: it would feel better to do this once, anytime, as our delay increases over time - if lighthouseTriggered && hostinfo.HandshakeCounter > 0 { - // If we didn't return here a lighthouse could cause us to aggressively send handshakes - return + // Increment the counter to increase our delay, linear backoff + hh.counter++ + + // Check if we have a handshake packet to transmit yet + if !hh.ready { + if !ixHandshakeStage0(hm.f, hh) { + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) + return + } } // Get a remotes object if we don't already have one. @@ -149,24 +208,38 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) } - //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) - if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { + remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) + remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + + // We only care about a lighthouse trigger if we have new remotes to send to. + // This is a very specific optimization for a fast lighthouse reply. 
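handleOutbound now bumps a per-handshake counter on every attempt and re-arms the timer wheel with tryInterval * counter, so retries back off linearly until the retry limit is reached, and lighthouse-triggered attempts are skipped when the known remotes have not changed. A rough sketch of the resulting schedule, using hypothetical interval and retry values rather than the real config defaults:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Hypothetical values; the real interval and retry count come from the
	// handshake config, not from this sketch.
	tryInterval := 100 * time.Millisecond
	retries := 10

	// Each attempt re-arms the timer wheel with tryInterval * counter,
	// so the gaps grow linearly: 100ms, 200ms, 300ms, ...
	var total time.Duration
	for counter := 1; counter <= retries; counter++ {
		delay := tryInterval * time.Duration(counter)
		total += delay
		fmt.Printf("attempt %2d: next retry in %v (elapsed %v)\n", counter, delay, total)
	}
	fmt.Printf("handshake declared timed out after roughly %v\n", total)
}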
+ if lighthouseTriggered && !remotesHaveChanged { + // If we didn't return here a lighthouse could cause us to aggressively send handshakes + return + } + + hh.lastRemotes = remotes + + // TODO: this will generate a load of queries for hosts with only 1 ip + // (such as ones registered to the lighthouse with only a private IP) + // So we only do it one time after attempting 5 handshakes already. + if len(remotes) <= 1 && hh.counter == 5 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIp, f) + hm.lightHouse.QueryServer(vpnIp) } - // Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply + // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr - hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { - c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { + hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(c.l).WithField("udpAddr", addr). + hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake message") @@ -176,103 +249,180 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l } }) - // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout - if len(sentTo) > 0 { - hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, + // so only log when the list of remotes has changed + if remotesHaveChanged { + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") + } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
+ Debug("Handshake message sent") } - if c.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(c.l).WithField("relayIps", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { + hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == c.lightHouse.myVpnIp { + if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { continue } - relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay) - if err != nil || relayHostInfo.remote == nil { - hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target.") - f.Handshake(*relay) + relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) + if relayHostInfo == nil || relayHostInfo.remote == nil { + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hm.f.Handshake(*relay) continue } // Check the relay HostInfo to see if we already established a relay through it if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { switch existingRelay.State { case Established: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay") - f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu)) + // This must send over the hostinfo, not over hm.Hosts[ip] + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, + "relayTo": vpnIp, + "initiatorRelayIndex": existingRelay.LocalIndex, + "relay": *relay}). + Info("send CreateRelayRequest") } default: - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("vpnIp", vpnIp). WithField("state", existingRelay.State). - WithField("relayVpnIp", relayHostInfo.vpnIp). + WithField("relay", relayHostInfo.vpnIp). Errorf("Relay unexpected state") } } else { // No relays exist or requested yet. 
if relayHostInfo.remote != nil { - idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested) + idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu)) + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, + "relayTo": vpnIp, + "initiatorRelayIndex": idx, + "relay": *relay}). + Info("send CreateRelayRequest") } } } } } - // Increment the counter to increase our delay, linear backoff - hostinfo.HandshakeCounter++ - // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { - //TODO: feel like we dupe handshake real fast in a tight loop, why? - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) } } -func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { - hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) +// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present +// The 2nd argument will be true if the hostinfo is ready to transmit traffic +func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { + // Check the main hostmap and maintain a read lock if our host is not there + hm.mainHostMap.RLock() + if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { + hm.mainHostMap.RUnlock() + // Do not attempt promotion if you are a lighthouse + if !hm.lightHouse.amLighthouse { + h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) + } + return h, true + } + + defer hm.mainHostMap.RUnlock() + return hm.StartHandshake(vpnIp, cacheCb), false +} + +// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip +func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { + hm.Lock() + + if hh, ok := hm.vpnIps[vpnIp]; ok { + // We are already trying to handshake with this vpn ip + if cacheCb != nil { + cacheCb(hh) + } + hm.Unlock() + return hh.hostinfo + } + + hostinfo := &HostInfo{ + vpnIp: vpnIp, + HandshakePacket: make(map[uint8][]byte, 0), + relayState: RelayState{ + relays: map[iputil.VpnIp]struct{}{}, + relayForByIp: map[iputil.VpnIp]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + } + + hh := &HandshakeHostInfo{ + hostinfo: hostinfo, + startTime: time.Now(), + } + hm.vpnIps[vpnIp] = hh + hm.metricInitiated.Inc(1) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + + if cacheCb != nil { + cacheCb(hh) + } + + // If this is a static host, we don't 
need to wait for the HostQueryReply + // We can trigger the handshake right now + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + if !doTrigger { + // Add any calculated remotes, and trigger early handshake if one found + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + } - if created { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - c.metricInitiated.Inc(1) + if doTrigger { + select { + case hm.trigger <- vpnIp: + default: + } } + hm.Unlock() + hm.lightHouse.QueryServer(vpnIp) return hostinfo } @@ -280,7 +430,6 @@ var ( ErrExistingHostInfo = errors.New("existing hostinfo") ErrAlreadySeen = errors.New("already seen") ErrLocalIndexCollision = errors.New("local index collision") - ErrExistingHandshake = errors.New("existing handshake") ) // CheckAndComplete checks for any conflicts in the main and pending hostmap @@ -294,22 +443,27 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() +func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { c.mainHostMap.Lock() defer c.mainHostMap.Unlock() + c.Lock() + defer c.Unlock() // Check if we already have a tunnel with this vpn ip existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] if found && existingHostInfo != nil { - // Is it just a delayed handshake packet? - if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { - return existingHostInfo, ErrAlreadySeen + testHostInfo := existingHostInfo + for testHostInfo != nil { + // Is it just a delayed handshake packet? + if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) { + return testHostInfo, ErrAlreadySeen + } + + testHostInfo = testHostInfo.next } // Is this a newer handshake? - if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime { + if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime && !existingHostInfo.ConnectionState.initiator { return existingHostInfo, ErrExistingHostInfo } @@ -322,8 +476,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingIndex, ErrLocalIndexCollision } - existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] - if found && existingIndex != hostinfo { + existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } @@ -337,90 +491,54 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket Info("New host shadows existing host remoteIndex") } - // Check if we are also handshaking with this vpn ip - pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp] - if found && pendingHostInfo != nil { - if !overwrite { - // We won, let our pending handshake win - return pendingHostInfo, ErrExistingHandshake - } - - // We lost, take this handshake and move any cached packets over so they get sent - pendingHostInfo.ConnectionState.queueLock.Lock() - hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...) 
- c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo) - pendingHostInfo.ConnectionState.queueLock.Unlock() - pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel") - } - - if existingHostInfo != nil { - // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) - delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) - delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) - for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() { - delete(c.mainHostMap.Relays, relayIdx) - } - } - - c.mainHostMap.addHostInfo(hostinfo, f) + c.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the -// pendingHostMap -func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] - if found && existingHostInfo != nil { - // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) - delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) - delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) - for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() { - delete(c.mainHostMap.Relays, relayIdx) - } - } - - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] +// pendingHostMap. An existing hostinfo is returned if there was one. +func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() + + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.addHostInfo(hostinfo, f) - c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) + // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. + hm.unlockedDeleteHostInfo(hostinfo) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) } -// AddIndexHostInfo generates a unique localIndexId for this HostInfo +// allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. 
Will error if we are unable to generate // a unique localIndexId -func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() - c.mainHostMap.RLock() - defer c.mainHostMap.RUnlock() +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { + hm.mainHostMap.RLock() + defer hm.mainHostMap.RUnlock() + hm.Lock() + defer hm.Unlock() for i := 0; i < 32; i++ { - index, err := generateIndex(c.l) + index, err := generateIndex(hm.l) if err != nil { return err } - _, inPending := c.pendingHostMap.Indexes[index] - _, inMain := c.mainHostMap.Indexes[index] + _, inPending := hm.indexes[index] + _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { - h.localIndexId = index - c.pendingHostMap.Indexes[index] = h + hh.hostinfo.localIndexId = index + hm.indexes[index] = hh return nil } } @@ -428,22 +546,90 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - c.pendingHostMap.addRemoteIndexHostInfo(index, h) +func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + c.Lock() + defer c.Unlock() + c.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - //l.Debugln("Deleting pending hostinfo :", hostinfo) - c.pendingHostMap.DeleteHostInfo(hostinfo) +func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + delete(c.vpnIps, hostinfo.vpnIp) + if len(c.vpnIps) == 0 { + c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} + } + + delete(c.indexes, hostinfo.localIndexId) + if len(c.vpnIps) == 0 { + c.indexes = map[uint32]*HandshakeHostInfo{} + } + + if c.l.Level >= logrus.DebugLevel { + c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), + "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
+ Debug("Pending hostmap hostInfo deleted") + } +} + +func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { + hh := hm.queryVpnIp(vpnIp) + if hh != nil { + return hh.hostinfo + } + return nil + +} + +func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.vpnIps[vpnIp] } -func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) { - return c.pendingHostMap.QueryIndex(index) +func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo { + hh := hm.queryIndex(index) + if hh != nil { + return hh.hostinfo + } + return nil +} + +func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.indexes[index] +} + +func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { + return c.mainHostMap.GetPreferredRanges() +} + +func (c *HandshakeManager) ForEachVpnIp(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.vpnIps { + f(v.hostinfo) + } +} + +func (c *HandshakeManager) ForEachIndex(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.indexes { + f(v.hostinfo) + } } func (c *HandshakeManager) EmitStats() { - c.pendingHostMap.EmitStats("pending") - c.mainHostMap.EmitStats("main") + c.RLock() + hostLen := len(c.vpnIps) + indexLen := len(c.indexes) + c.RUnlock() + + metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) + c.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 413a50abd..9a6335757 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" @@ -14,96 +15,55 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := newTestLighthouse() - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) + mainHM := newHostMap(l, vpncidr) + mainHM.preferredRanges.Store(&preferredRanges) - now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) + lh := newTestLighthouse() - var initCalled bool - initFunc := func(*HostInfo) { - initCalled = true + cs := &CertState{ + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } - i := blah.AddVpnIp(ip, initFunc) - assert.True(t, initCalled) + blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) + blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l} + blah.f.pki.cs.Store(cs) - initCalled = false - i2 := blah.AddVpnIp(ip, initFunc) - assert.False(t, initCalled) + now := time.Now() + blah.NextOutboundHandshakeTimerTick(now) + + i := blah.StartHandshake(ip, nil) + i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList() - i.HandshakeReady = true + i.remotes = NewRemoteList(nil) // Adding something to pending should not affect the main hostmap 
assert.Len(t, mainHM.Hosts, 0) // Confirm they are in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right for i := 1; i <= DefaultHandshakeRetries+1; i++ { now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval) - blah.NextOutboundHandshakeTimerTick(now, mw) + blah.NextOutboundHandshakeTimerTick(now) } // Confirm they are still in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Tick 1 more time, a minute will certainly flush it out - blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw) + blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute)) // Confirm they have been removed - assert.NotContains(t, blah.pendingHostMap.Hosts, ip) -} - -func Test_NewHandshakeManagerTrigger(t *testing.T) { - l := test.NewLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := newTestLighthouse() - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) - - now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) - - assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - - hi := blah.AddVpnIp(ip, nil) - hi.HandshakeReady = true - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") - - // Trigger the same method the channel will but, this should set our remotes pointer - blah.handleOutbound(ip, mw, true) - assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt") - assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer") - - // Make sure the trigger doesn't double schedule the timer entry - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - - uaddr := udp.NewAddrFromString("10.1.1.1:4242") - hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))) - - // We now have remotes but only the first trigger should have pushed things forward - blah.handleOutbound(ip, mw, true) - assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt") - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) + assert.NotContains(t, blah.vpnIps, ip) } func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { @@ -124,7 +84,11 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess return } -func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { + return +} + +func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { return } diff --git a/hostmap.go b/hostmap.go index 372333eb2..589a12463 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,9 +1,7 @@ package nebula import ( - "context" "errors" - "fmt" "net" "sync" "sync/atomic" @@ -13,15 +11,22 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" 
"github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 -const PromoteEvery = 1000 -const ReQueryEvery = 5000 +const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address +const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse +const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 +const maxRecvError = 4 + +// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip +// 5 allows for an initial handshake and each host pair re-handshaking twice +const MaxHostInfosPerVpnIp = 5 // How long we should prevent roaming back to the previous IP. // This helps prevent flapping due to packets already in flight @@ -29,6 +34,7 @@ const RoamingSuppressSeconds = 2 const ( Requested = iota + PeerRequested Established ) @@ -48,17 +54,18 @@ type Relay struct { type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps - name string Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo Hosts map[iputil.VpnIp]*HostInfo - preferredRanges []*net.IPNet + preferredRanges atomic.Pointer[[]*net.IPNet] vpnCIDR *net.IPNet - metricsEnabled bool l *logrus.Logger } +// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay +// struct, make a copy of an existing value, edit the fileds in the copy, and +// then store a pointer to the new copy in both realyForBy* maps. 
type RelayState struct { sync.RWMutex @@ -73,6 +80,16 @@ func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { delete(rs.relays, ip) } +func (rs *RelayState) CopyAllRelayFor() []*Relay { + rs.RLock() + defer rs.RUnlock() + ret := make([]*Relay, 0, len(rs.relayForByIdx)) + for _, r := range rs.relayForByIdx { + ret = append(ret, r) + } + return ret +} + func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { rs.RLock() defer rs.RUnlock() @@ -119,13 +136,43 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { rs.Lock() defer rs.Unlock() - relay, ok := rs.relayForByIdx[localIdx] + r, ok := rs.relayForByIdx[localIdx] if !ok { return iputil.VpnIp(0), false } delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, relay.PeerIp) - return relay.PeerIp, true + delete(rs.relayForByIp, r.PeerIp) + return r.PeerIp, true +} + +func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { + rs.Lock() + defer rs.Unlock() + r, ok := rs.relayForByIp[vpnIp] + if !ok { + return false + } + newRelay := *r + newRelay.State = Established + newRelay.RemoteIndex = remoteIdx + rs.relayForByIdx[r.LocalIndex] = &newRelay + rs.relayForByIp[r.PeerIp] = &newRelay + return true +} + +func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Relay, bool) { + rs.Lock() + defer rs.Unlock() + r, ok := rs.relayForByIdx[localIdx] + if !ok { + return nil, false + } + newRelay := *r + newRelay.State = Established + newRelay.RemoteIndex = remoteIdx + rs.relayForByIdx[r.LocalIndex] = &newRelay + rs.relayForByIp[r.PeerIp] = &newRelay + return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { @@ -141,6 +188,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { r, ok := rs.relayForByIdx[idx] return r, ok } + func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() @@ -149,24 +197,24 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - sync.RWMutex - - remote *udp.Addr - remotes *RemoteList - promoteCounter atomic.Uint32 - ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry - packetStore []*cachedPacket //todo: this is other handshake manager entry - remoteIndexId uint32 - localIndexId uint32 - vpnIp iputil.VpnIp - recvError int - remoteCidr *cidr.Tree4 - relayState RelayState + remote *udp.Addr + remotes *RemoteList + promoteCounter atomic.Uint32 + ConnectionState *ConnectionState + remoteIndexId uint32 + localIndexId uint32 + vpnIp iputil.VpnIp + recvError atomic.Uint32 + remoteCidr *cidr.Tree4[struct{}] + relayState RelayState + + // HandshakePacket records the packets used to create this hostinfo + // We need these to avoid replayed handshake packets creating new hostinfos which causes churn + HandshakePacket map[uint8][]byte + + // nextLHQuery is the earliest we can ask the lighthouse for new information. 
+ // This is used to limit lighthouse re-queries in chatty clients + nextLHQuery atomic.Int64 // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like @@ -180,6 +228,10 @@ type HostInfo struct { lastRoam time.Time lastRoamRemote *udp.Addr + + // Used to track other hostinfos for this vpn ip since only 1 can be primary + // Synchronised via hostmap lock and not the hostinfo lock. + next, prev *HostInfo } type ViaSender struct { @@ -202,26 +254,57 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { - h := map[iputil.VpnIp]*HostInfo{} - i := map[uint32]*HostInfo{} - r := map[uint32]*HostInfo{} - relays := map[uint32]*HostInfo{} - m := HostMap{ - name: name, - Indexes: i, - Relays: relays, - RemoteIndexes: r, - Hosts: h, - preferredRanges: preferredRanges, - vpnCIDR: vpnCIDR, - l: l, +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { + hm := newHostMap(l, vpnCIDR) + + hm.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + hm.reload(c, false) + }) + + l.WithField("network", hm.vpnCIDR.String()). + WithField("preferredRanges", hm.GetPreferredRanges()). + Info("Main HostMap created") + + return hm +} + +func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { + return &HostMap{ + Indexes: map[uint32]*HostInfo{}, + Relays: map[uint32]*HostInfo{}, + RemoteIndexes: map[uint32]*HostInfo{}, + Hosts: map[iputil.VpnIp]*HostInfo{}, + vpnCIDR: vpnCIDR, + l: l, + } +} + +func (hm *HostMap) reload(c *config.C, initial bool) { + if initial || c.HasChanged("preferred_ranges") { + var preferredRanges []*net.IPNet + rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) + + for _, rawPreferredRange := range rawPreferredRanges { + _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + + if err != nil { + hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + continue + } + + preferredRanges = append(preferredRanges, preferredRange) + } + + oldRanges := hm.preferredRanges.Swap(&preferredRanges) + if !initial { + hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + } } - return &m } -// UpdateStats takes a name and reports host and index counts to the stats collection system -func (hm *HostMap) EmitStats(name string) { +// EmitStats reports host, index, and relay counts to the stats collection system +func (hm *HostMap) EmitStats() { hm.RLock() hostLen := len(hm.Hosts) indexLen := len(hm.Indexes) @@ -229,361 +312,249 @@ func (hm *HostMap) EmitStats(name string) { relaysLen := len(hm.Relays) hm.RUnlock() - metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen)) + metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen)) + metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen)) + 
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } func (hm *HostMap) RemoveRelay(localIdx uint32) { hm.Lock() - hiRelay, ok := hm.Relays[localIdx] + _, ok := hm.Relays[localIdx] if !ok { hm.Unlock() return } delete(hm.Relays, localIdx) hm.Unlock() - ip, ok := hiRelay.relayState.RemoveRelay(localIdx) - if !ok { - return - } - hiPeer, err := hm.QueryVpnIp(ip) - if err != nil { - return - } - var otherPeerIdx uint32 - hiPeer.relayState.DeleteRelay(hiRelay.vpnIp) - relay, ok := hiPeer.relayState.GetRelayForByIp(hiRelay.vpnIp) - if ok { - otherPeerIdx = relay.LocalIndex - } - // I am a relaying host. I need to remove the other relay, too. - hm.RemoveRelay(otherPeerIdx) -} - -func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) { - hm.RLock() - if i, ok := hm.Hosts[vpnIp]; ok { - index := i.localIndexId - hm.RUnlock() - return index, nil - } - hm.RUnlock() - return 0, errors.New("vpn IP not found") } -func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) { +// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip +func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { + // Delete the host itself, ensuring it's not modified anymore hm.Lock() - hm.Hosts[ip] = hostinfo + // If we have a previous or next hostinfo then we are not the last one for this vpn ip + final := (hostinfo.next == nil && hostinfo.prev == nil) + hm.unlockedDeleteHostInfo(hostinfo) hm.Unlock() -} -func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) { - hm.RLock() - if h, ok := hm.Hosts[vpnIp]; !ok { - hm.RUnlock() - h = &HostInfo{ - vpnIp: vpnIp, - HandshakePacket: make(map[uint8][]byte, 0), - relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - if init != nil { - init(h) - } - hm.Lock() - hm.Hosts[vpnIp] = h - hm.Unlock() - return h, true - } else { - hm.RUnlock() - return h, false - } + return final } -func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { hm.Lock() - delete(hm.Hosts, vpnIp) - if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} - } - hm.Unlock() - - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}). - Debug("Hostmap vpnIp deleted") - } + defer hm.Unlock() + hm.unlockedMakePrimary(hostinfo) } -// Only used by pendingHostMap when the remote index is not initially known -func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - hm.Lock() - h.remoteIndexId = index - hm.RemoteIndexes[index] = h - hm.Unlock() - - if hm.l.Level > logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}). 
- Debug("Hostmap remoteIndex added") +func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { + oldHostinfo := hm.Hosts[hostinfo.vpnIp] + if oldHostinfo == hostinfo { + return } -} -func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) { - hm.Lock() - h.vpnIp = vpnIp - hm.Hosts[vpnIp] = h - hm.Indexes[h.localIndexId] = h - hm.RemoteIndexes[h.remoteIndexId] = h - hm.Unlock() + if hostinfo.prev != nil { + hostinfo.prev.next = hostinfo.next + } - if hm.l.Level > logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}). - Debug("Hostmap vpnIp added") + if hostinfo.next != nil { + hostinfo.next.prev = hostinfo.prev } -} -// This is only called in pendingHostmap, to cleanup an inbound handshake -func (hm *HostMap) DeleteIndex(index uint32) { - hm.Lock() - hostinfo, ok := hm.Indexes[index] - if ok { - delete(hm.Indexes, index) - delete(hm.RemoteIndexes, hostinfo.remoteIndexId) + hm.Hosts[hostinfo.vpnIp] = hostinfo - // Check if we have an entry under hostId that matches the same hostinfo - // instance. Clean it up as well if we do. - hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.vpnIp) - } + if oldHostinfo == nil { + return } - hm.Unlock() - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). - Debug("Hostmap index deleted") - } + hostinfo.next = oldHostinfo + oldHostinfo.prev = hostinfo + hostinfo.prev = nil } -// This is used to cleanup on recv_error -func (hm *HostMap) DeleteReverseIndex(index uint32) { - hm.Lock() - hostinfo, ok := hm.RemoteIndexes[index] - if ok { - delete(hm.Indexes, hostinfo.localIndexId) - delete(hm.RemoteIndexes, index) - - // Check if we have an entry under hostId that matches the same hostinfo - // instance. Clean it up as well if we do (they might not match in pendingHostmap) - var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.vpnIp) +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { + primary, ok := hm.Hosts[hostinfo.vpnIp] + if ok && primary == hostinfo { + // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, hostinfo.vpnIp) + if len(hm.Hosts) == 0 { + hm.Hosts = map[iputil.VpnIp]*HostInfo{} } - } - hm.Unlock() - - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). 
- Debug("Hostmap remote index deleted") - } -} -func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { - // Delete the host itself, ensuring it's not modified anymore - hm.Lock() - hm.unlockedDeleteHostInfo(hostinfo) - hm.Unlock() + if hostinfo.next != nil { + // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary + hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // It is primary, there is no previous hostinfo now + hostinfo.next.prev = nil + } - // And tear down all the relays going through this host - for _, localIdx := range hostinfo.relayState.CopyRelayForIdxs() { - hm.RemoveRelay(localIdx) - } + } else { + // Relink if we were in the middle of multiple hostinfos for this vpn ip + if hostinfo.prev != nil { + hostinfo.prev.next = hostinfo.next + } - // And tear down the relays this deleted hostInfo was using to be reached - teardownRelayIdx := []uint32{} - for _, relayIp := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, err := hm.QueryVpnIp(relayIp) - if err != nil { - hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap") - } else { - if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok { - teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex) - } + if hostinfo.next != nil { + hostinfo.next.prev = hostinfo.prev } } - for _, localIdx := range teardownRelayIdx { - hm.RemoveRelay(localIdx) - } -} -func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { - hm.Lock() - defer hm.Unlock() - delete(hm.RemoteIndexes, localIdx) -} + hostinfo.next = nil + hostinfo.prev = nil -func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - // Check if this same hostId is in the hostmap with a different instance. - // This could happen if we have an entry in the pending hostmap with different - // index values than the one in the main hostmap. - hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 != hostinfo { - delete(hm.Hosts, hostinfo2.vpnIp) - delete(hm.Indexes, hostinfo2.localIndexId) - delete(hm.RemoteIndexes, hostinfo2.remoteIndexId) + // The remote index uses index ids outside our control so lets make sure we are only removing + // the remote index pointer here if it points to the hostinfo we are deleting + hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId] + if ok && hostinfo2 == hostinfo { + delete(hm.RemoteIndexes, hostinfo.remoteIndexId) + if len(hm.RemoteIndexes) == 0 { + hm.RemoteIndexes = map[uint32]*HostInfo{} + } } - delete(hm.Hosts, hostinfo.vpnIp) - if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} - } delete(hm.Indexes, hostinfo.localIndexId) if len(hm.Indexes) == 0 { hm.Indexes = map[uint32]*HostInfo{} } - delete(hm.RemoteIndexes, hostinfo.remoteIndexId) - if len(hm.RemoteIndexes) == 0 { - hm.RemoteIndexes = map[uint32]*HostInfo{} - } if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
Debug("Hostmap hostInfo deleted") } + + for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { + delete(hm.Relays, localRelayIdx) + } } -func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error +func (hm *HostMap) QueryIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Indexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error + +func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { +func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.RemoteIndexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) + return nil } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) { +func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } -// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every -// `PromoteEvery` calls to this function for a given host. -func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) { - return hm.queryVpnIp(vpnIp, ifce) +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { + hm.RLock() + defer hm.RUnlock() + + h, ok := hm.Hosts[relayHostIp] + if !ok { + return nil, nil, errors.New("unable to find host") + } + for h != nil { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } + h = h.next + } + return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) { +func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() // Do not attempt promotion if you are a lighthouse if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { - h.TryPromoteBest(hm.preferredRanges, promoteIfce) + h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce) } - return h, nil + return h } hm.RUnlock() - return nil, errors.New("unable to find host") + return nil } -// We already have the hm Lock when this is called, so make sure to not call -// any other methods that might try to grab it again -func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { +// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. 
+// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary +func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } + existing := hm.Hosts[hostinfo.vpnIp] hm.Hosts[hostinfo.vpnIp] = hostinfo + + if existing != nil { + hostinfo.next = existing + existing.prev = hostinfo + } + hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). Debug("Hostmap vpnIp added") } + + i := 1 + check := hostinfo + for check != nil { + if i > MaxHostInfosPerVpnIp { + hm.unlockedDeleteHostInfo(check) + } + check = check.next + i++ + } } -// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap -// The caller can then do the its work outside of the read lock -func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList { +func (hm *HostMap) GetPreferredRanges() []*net.IPNet { + //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer + return *hm.preferredRanges.Load() +} + +func (hm *HostMap) ForEachVpnIp(f controlEach) { hm.RLock() defer hm.RUnlock() for _, v := range hm.Hosts { - if v.remotes != nil { - rl = append(rl, v.remotes) - } + f(v) } - return rl } -// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them -func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { - var metricsTxPunchy metrics.Counter - if hm.metricsEnabled { - metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) - } else { - metricsTxPunchy = metrics.NilCounter{} - } - - var remotes []*RemoteList - b := []byte{1} - - clockSource := time.NewTicker(time.Second * 10) - defer clockSource.Stop() - - for { - remotes = hm.punchList(remotes[:0]) - for _, rl := range remotes { - //TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better - for _, addr := range rl.CopyAddrs(hm.preferredRanges) { - metricsTxPunchy.Inc(1) - conn.WriteTo(b, addr) - } - } +func (hm *HostMap) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - select { - case <-ctx.Done(): - return - case <-clockSource.C: - continue - } + for _, v := range hm.Indexes { + f(v) } } @@ -591,11 +562,8 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! 
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { c := i.promoteCounter.Add(1) - if c%PromoteEvery == 0 { - // The lock here is currently protecting i.remote access - i.RLock() + if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote - i.RUnlock() // return early if we are already on a preferred remote if remote != nil { @@ -619,66 +587,17 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } // Re query our lighthouses for new remotes occasionally - if c%ReQueryEvery == 0 && ifce.lightHouse != nil { - ifce.lightHouse.QueryServer(i.vpnIp, ifce) - } -} - -func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - //TODO: return the error so we can log with more context - if len(i.packetStore) < 100 { - tempPacket := make([]byte, len(packet)) - copy(tempPacket, packet) - //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) - i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", true). - Debugf("Packet store") + if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil { + now := time.Now().UnixNano() + if now < i.nextLHQuery.Load() { + return } - } else if l.Level >= logrus.DebugLevel { - m.dropped.Inc(1) - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", false). - Debugf("Packet store") + i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) + ifce.lightHouse.QueryServer(i.vpnIp) } } -// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets -func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { - //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: - //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send - //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical - - i.ConnectionState.queueLock.Lock() - i.HandshakeComplete = true - //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. - // Clamping it to 2 gets us out of the woods for now - i.ConnectionState.messageCounter.Store(2) - - if l.Level >= logrus.DebugLevel { - i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) - } - - if len(i.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range i.packetStore { - cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out) - } - m.sent.Inc(int64(len(i.packetStore))) - } - - i.remotes.ResetBlockedRemotes() - i.packetStore = make([]*cachedPacket, 0) - i.ConnectionState.ready = true - i.ConnectionState.queueLock.Unlock() - i.ConnectionState.certState = nil -} - func (i *HostInfo) GetCert() *cert.NebulaCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert @@ -710,7 +629,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { // NOTE: We do this loop here instead of calling `isPreferred` in // remote_list.go so that we only have to loop over preferredRanges once. 
newIsPreferred := false - for _, l := range hm.preferredRanges { + for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote if l.Contains(currentRemote.IP) { return false @@ -735,9 +654,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { } func (i *HostInfo) RecvErrorExceeded() bool { - if i.recvError < 3 { - i.recvError += 1 - return false + if i.recvError.Add(1) >= maxRecvError { + return true } return true } @@ -748,7 +666,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4() + remoteCidr := cidr.NewTree4[struct{}]() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } diff --git a/hostmap_test.go b/hostmap_test.go index 2808317c2..8311cef0b 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1 +1,236 @@ package nebula + +import ( + "net" + "testing" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +func TestHostMap_MakePrimary(t *testing.T) { + l := test.NewLogger() + hm := newHostMap( + l, + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + ) + + f := &Interface{} + + h1 := &HostInfo{vpnIp: 1, localIndexId: 1} + h2 := &HostInfo{vpnIp: 1, localIndexId: 2} + h3 := &HostInfo{vpnIp: 1, localIndexId: 3} + h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + + hm.unlockedAddHostInfo(h4, f) + hm.unlockedAddHostInfo(h3, f) + hm.unlockedAddHostInfo(h2, f) + hm.unlockedAddHostInfo(h1, f) + + // Make sure we go h1 -> h2 -> h3 -> h4 + prim := hm.QueryVpnIp(1) + assert.Equal(t, h1.localIndexId, prim.localIndexId) + assert.Equal(t, h2.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Swap h3/middle to primary + hm.MakePrimary(h3) + + // Make sure we go h3 -> h1 -> h2 -> h4 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h3.localIndexId, prim.localIndexId) + assert.Equal(t, h1.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Swap h4/tail to primary + hm.MakePrimary(h4) + + // Make sure we go h4 -> h3 -> h1 -> h2 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h3.next.localIndexId) + assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Nil(t, h2.next) + + // Swap h4 again should be no-op + hm.MakePrimary(h4) + + // Make sure we go h4 -> h3 -> h1 -> h2 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + 
assert.Equal(t, h1.localIndexId, h3.next.localIndexId) + assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Nil(t, h2.next) +} + +func TestHostMap_DeleteHostInfo(t *testing.T) { + l := test.NewLogger() + hm := newHostMap( + l, + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + ) + + f := &Interface{} + + h1 := &HostInfo{vpnIp: 1, localIndexId: 1} + h2 := &HostInfo{vpnIp: 1, localIndexId: 2} + h3 := &HostInfo{vpnIp: 1, localIndexId: 3} + h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h5 := &HostInfo{vpnIp: 1, localIndexId: 5} + h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + + hm.unlockedAddHostInfo(h6, f) + hm.unlockedAddHostInfo(h5, f) + hm.unlockedAddHostInfo(h4, f) + hm.unlockedAddHostInfo(h3, f) + hm.unlockedAddHostInfo(h2, f) + hm.unlockedAddHostInfo(h1, f) + + // h6 should be deleted + assert.Nil(t, h6.next) + assert.Nil(t, h6.prev) + h := hm.QueryIndex(h6.localIndexId) + assert.Nil(t, h) + + // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 + prim := hm.QueryVpnIp(1) + assert.Equal(t, h1.localIndexId, prim.localIndexId) + assert.Equal(t, h2.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete primary + hm.DeleteHostInfo(h1) + assert.Nil(t, h1.prev) + assert.Nil(t, h1.next) + + // Make sure we go h2 -> h3 -> h4 -> h5 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete in the middle + hm.DeleteHostInfo(h3) + assert.Nil(t, h3.prev) + assert.Nil(t, h3.next) + + // Make sure we go h2 -> h4 -> h5 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h4.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete the tail + hm.DeleteHostInfo(h5) + assert.Nil(t, h5.prev) + assert.Nil(t, h5.next) + + // Make sure we go h2 -> h4 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h4.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Delete the head + hm.DeleteHostInfo(h2) + assert.Nil(t, h2.prev) + assert.Nil(t, h2.next) + + // Make sure we only 
have h4 + prim = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Nil(t, prim.prev) + assert.Nil(t, prim.next) + assert.Nil(t, h4.next) + + // Delete the only item + hm.DeleteHostInfo(h4) + assert.Nil(t, h4.prev) + assert.Nil(t, h4.next) + + // Make sure we have nil + prim = hm.QueryVpnIp(1) + assert.Nil(t, prim) +} + +func TestHostMap_reload(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + + hm := NewHostMapFromConfig( + l, + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + c, + ) + + toS := func(ipn []*net.IPNet) []string { + var s []string + for _, n := range ipn { + s = append(s, n.String()) + } + return s + } + + assert.Empty(t, hm.GetPreferredRanges()) + + c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") + assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + + c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") + assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) +} diff --git a/hostmap_tester.go b/hostmap_tester.go index 1d4323fd9..0d5d41bf7 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -19,6 +19,6 @@ func (i *HostInfo) GetRemoteIndex() uint32 { return i.remoteIndexId } -func (i *HostInfo) GetRelayState() RelayState { - return i.relayState +func (i *HostInfo) GetRelayState() *RelayState { + return &i.relayState } diff --git a/inside.go b/inside.go index 38d9332c2..079e4dd2f 100644 --- a/inside.go +++ b/inside.go @@ -1,11 +1,11 @@ package nebula import ( - "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/udp" ) @@ -44,8 +44,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo := f.getOrHandshake(fwPacket.RemoteIP) + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + if hostinfo == nil { + f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", fwPacket.RemoteIP). WithField("fwPacket", fwPacket). @@ -53,95 +57,83 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } return } - ci := hostinfo.ConnectionState - - if ci.ready == false { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - ci.queueLock.Lock() - if !ci.ready { - hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - ci.queueLock.Unlock() - return - } - ci.queueLock.Unlock() + + if !ready { + return } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) - } else if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping outbound packet") + } else { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). 
+ Debugln("dropping outbound packet") + } } } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { - f.getOrHandshake(vpnIp) -} +func (f *Interface) rejectInside(packet []byte, out []byte, q int) { + if !f.firewall.InSendReject { + return + } -// getOrHandshake returns nil if the vpnIp is not routable -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { - vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { - return nil - } + out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return } - hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) - //if err != nil || hostinfo.ConnectionState == nil { + _, err := f.readers[q].Write(out) if err != nil { - hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) - } + f.l.WithError(err).Error("Failed to write to tun") } - ci := hostinfo.ConnectionState +} - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo +func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, nb, out []byte, q int) { + if !f.firewall.OutSendReject { + return } - // Handshake is not ready, we need to grab the lock now before we start the handshake process - hostinfo.Lock() - defer hostinfo.Unlock() - - // Double check, now that we have the lock - ci = hostinfo.ConnectionState - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo + out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return } - // If we have already created the handshake packet, we don't want to call the function at all. - if !hostinfo.HandshakeReady { - ixHandshakeStage0(f, vpnIp, hostinfo) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //xx_handshakeStage0(f, ip, hostinfo) - - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger the handshake right now - if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok { - select { - case f.handshakeManager.trigger <- vpnIp: - default: - } + if len(out) > iputil.MaxRejectPacketSize { + if f.l.GetLevel() >= logrus.InfoLevel { + f.l. + WithField("packet", packet). + WithField("outPacket", out). + Info("rejectOutside: packet too big, not sending") } + return } - return hostinfo + f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) +} + +func (f *Interface) Handshake(vpnIp iputil.VpnIp) { + f.getOrHandshake(vpnIp, nil) } -// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that -// will create the initial Noise ConnectionState -func (f *Interface) initHostInfo(hostinfo *HostInfo) { - hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) +// getOrHandshake returns nil if the vpnIp is not routable. 
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { + vpnIp = f.inside.RouteFor(vpnIp) + if vpnIp == 0 { + return nil, false + } + } + + return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) } -func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { +func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { @@ -150,7 +142,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil) + dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). @@ -160,12 +152,15 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo := f.getOrHandshake(vpnIp) + hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) + }) + if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", vpnIp). @@ -174,24 +169,15 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu return } - if !hostInfo.ConnectionState.ready { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - hostInfo.ConnectionState.queueLock.Lock() - if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp, f.cachedPacketMetrics) - hostInfo.ConnectionState.queueLock.Unlock() - return - } - hostInfo.ConnectionState.queueLock.Unlock() + if !ready { + return } - f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out) - return + f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) } -func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { - f.send(t, st, hostInfo.ConnectionState, hostInfo, p, nb, out) +func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) { + f.send(t, st, hi.ConnectionState, hi, p, nb, out) } func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { @@ -204,7 +190,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } -// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done +// SendVia sends a payload through a Relay tunnel. 
No authentication or encryption is done // to the payload for the ultimate target host, making this a useful method for sending // handshake messages to peers through relay tunnels. // via is the HostInfo through which the message is relayed. @@ -212,15 +198,17 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C // nb is a buffer used to store the nonce value, re-used for performance reasons. // out is a buffer used to store the result of the Encrypt operation // q indicates which writer to use to send the packet. -func (f *Interface) SendVia(viaIfc interface{}, - relayIfc interface{}, +func (f *Interface) SendVia(via *HostInfo, + relay *Relay, ad, nb, out []byte, nocopy bool, ) { - via := viaIfc.(*HostInfo) - relay := relayIfc.(*Relay) + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + via.ConnectionState.writeLock.Lock() + } c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) @@ -229,6 +217,9 @@ func (f *Interface) SendVia(viaIfc interface{}, // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) { + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } via.logger(f.l). WithField("outCap", cap(out)). WithField("payloadLen", len(ad)). @@ -250,6 +241,9 @@ func (f *Interface) SendVia(viaIfc interface{}, var err error out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb) + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } if err != nil { via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return @@ -258,6 +252,7 @@ func (f *Interface) SendVia(viaIfc interface{}, if err != nil { via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") } + f.connectionManager.RelayUsed(relay.LocalIndex) } func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { @@ -278,8 +273,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType out = out[header.Len:] } - //TODO: enable if we do more than 1 tun queue - //ci.writeLock.Lock() + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + ci.writeLock.Lock() + } c := ci.messageCounter.Add(1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) @@ -291,7 +288,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
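The write lock around the message counter is only taken when `noiseutil.EncryptLockNeeded` is set because, per the NOTE above, the BoringCrypto AES-GCM-TLS cipher enforces a nonce check, so the counter increment and the encrypt call must form one critical section. A minimal sketch of that pattern (the struct and helper below are illustrative, not Nebula's actual `ConnectionState`):

```go
package example

import (
	"sync"
	"sync/atomic"
)

// sendState stands in for the counter/lock pair carried on ConnectionState.
type sendState struct {
	writeLock      sync.Mutex
	messageCounter atomic.Uint64
}

// nextCounter returns the next nonce counter plus a release function. When the
// cipher enforces strictly increasing nonces (lockNeeded), the increment and
// the subsequent encrypt call share one critical section; otherwise the atomic
// increment alone is enough and release is a no-op.
func (s *sendState) nextCounter(lockNeeded bool) (counter uint64, release func()) {
	if lockNeeded {
		s.writeLock.Lock()
		return s.messageCounter.Add(1), s.writeLock.Unlock
	}
	return s.messageCounter.Add(1), func() {}
}
```

A caller would run `c, release := st.nextCounter(noiseutil.EncryptLockNeeded)`, encrypt with counter `c`, then call `release()`, mirroring the lock and unlock points around `EncryptDanger` in the hunks above.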
- f.lightHouse.QueryServer(hostinfo.vpnIp, f) + f.lightHouse.QueryServer(hostinfo.vpnIp) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") @@ -300,8 +297,9 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType var err error out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) - //TODO: see above note on lock - //ci.writeLock.Unlock() + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).WithField("counter", c). @@ -325,31 +323,19 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, err := f.hostMap.QueryVpnIp(relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) if err != nil { - hostinfo.logger(f.l).WithField("relayIp", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") - continue - } - relay, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp) - if !ok { - hostinfo.logger(f.l). - WithField("relayIp", relayHostInfo.vpnIp). - WithField("relayTarget", hostinfo.vpnIp). - Info("sendNoMetrics relay missing object for target") + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) break } } - return } func isMulticast(ip iputil.VpnIp) bool { // Class D multicast - if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 { - return true - } - - return false + return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 } diff --git a/interface.go b/interface.go index 632e823f8..d16348aac 100644 --- a/interface.go +++ b/interface.go @@ -13,9 +13,9 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -25,24 +25,27 @@ const mtu = 9001 type InterfaceConfig struct { HostMap *HostMap - Outside *udp.Conn + Outside udp.Conn Inside overlay.Device - certState *CertState + pki *PKI Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager lightHouse *LightHouse - checkInterval int - pendingDeletionInterval int + checkInterval time.Duration + pendingDeletionInterval time.Duration DropLocalBroadcast bool DropMulticast bool routines int MessageMetrics *MessageMetrics version string - caPool *cert.NebulaCAPool - disconnectInvalid bool relayManager *relayManager + punchy *Punchy + + tryPromoteEvery uint32 + reQueryEvery uint32 + reQueryWait time.Duration ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -50,9 +53,9 @@ type InterfaceConfig struct { type Interface struct { hostMap *HostMap - outside *udp.Conn + outside udp.Conn inside overlay.Device - certState *CertState + pki *PKI cipher string firewall *Firewall connectionManager *connectionManager @@ -65,11 +68,14 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool routines int - caPool *cert.NebulaCAPool - disconnectInvalid bool + disconnectInvalid atomic.Bool closed atomic.Bool relayManager *relayManager + tryPromoteEvery 
atomic.Uint32 + reQueryEvery atomic.Uint32 + reQueryWait atomic.Int64 + sendRecvErrorConfig sendRecvErrorConfig // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse @@ -78,7 +84,7 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []*udp.Conn + writers []udp.Conn readers []io.ReadWriteCloser metricHandshakes metrics.Histogram @@ -88,6 +94,19 @@ type Interface struct { l *logrus.Logger } +type EncWriter interface { + SendVia(via *HostInfo, + relay *Relay, + ad, + nb, + out []byte, + nocopy bool, + ) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) + Handshake(vpnIp iputil.VpnIp) +} + type sendRecvErrorConfig uint8 const ( @@ -129,34 +148,33 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Inside == nil { return nil, errors.New("no inside interface (tun)") } - if c.certState == nil { + if c.pki == nil { return nil, errors.New("no certificate state") } if c.Firewall == nil { return nil, errors.New("no firewall rules") } - myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) + certificate := c.pki.GetCertState().Certificate + myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) ifce := &Interface{ + pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, - certState: c.certState, cipher: c.Cipher, firewall: c.Firewall, serveDns: c.ServeDns, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask), + localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, - writers: make([]*udp.Conn, c.routines), + writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - caPool: c.caPool, - disconnectInvalid: c.disconnectInvalid, myVpnIp: myVpnIp, relayManager: c.relayManager, @@ -172,7 +190,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval) + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) + ifce.reQueryEvery.Store(c.reQueryEvery) + ifce.reQueryWait.Store(int64(c.reQueryWait)) + + ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) return ifce, nil } @@ -190,6 +212,7 @@ func (f *Interface) activate() { f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). WithField("build", f.version).WithField("udpAddr", addr). + WithField("boringcrypto", boringEnabled()). 
Info("Nebula interface is active") metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) @@ -227,7 +250,7 @@ func (f *Interface) run() { func (f *Interface) listenOut(i int) { runtime.LockOSThread() - var li *udp.Conn + var li udp.Conn // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] @@ -237,7 +260,7 @@ func (f *Interface) listenOut(i int) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) + li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -267,46 +290,24 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { - c.RegisterReloadCallback(f.reloadCA) - c.RegisterReloadCallback(f.reloadCertKey) c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) + c.RegisterReloadCallback(f.reloadDisconnectInvalid) + c.RegisterReloadCallback(f.reloadMisc) + for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) } } -func (f *Interface) reloadCA(c *config.C) { - // reload and check regardless - // todo: need mutex? - newCAs, err := loadCAFromConfig(f.l, c) - if err != nil { - f.l.WithError(err).Error("Could not refresh trusted CA certificates") - return - } - - f.caPool = newCAs - f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") -} - -func (f *Interface) reloadCertKey(c *config.C) { - // reload and check in all cases - cs, err := NewCertStateFromConfig(c) - if err != nil { - f.l.WithError(err).Error("Could not refresh client cert") - return - } - - // did IP in cert change? if so, don't set - oldIPs := f.certState.certificate.Details.Ips - newIPs := cs.certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { - f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") - return +func (f *Interface) reloadDisconnectInvalid(c *config.C) { + initial := c.InitialLoad() + if initial || c.HasChanged("pki.disconnect_invalid") { + f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) + if !initial { + f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + } } - - f.certState = cs - f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") } func (f *Interface) reloadFirewall(c *config.C) { @@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -331,8 +332,8 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHash", fw.GetRuleHash()). - WithField("oldFirewallHash", oldFw.GetRuleHash()). + f.l.WithField("firewallHashes", fw.GetRuleHashes()). + WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). 
Warn("firewall rulesVersion has overflowed, resetting conntrack") } else { @@ -342,8 +343,8 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHash", fw.GetRuleHash()). - WithField("oldFirewallHash", oldFw.GetRuleHash()). + f.l.WithField("firewallHashes", fw.GetRuleHashes()). + WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). Info("New firewall has been installed") } @@ -372,12 +373,34 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } +func (f *Interface) reloadMisc(c *config.C) { + if c.HasChanged("counters.try_promote") { + n := c.GetUint32("counters.try_promote", defaultPromoteEvery) + f.tryPromoteEvery.Store(n) + f.l.Info("counters.try_promote has changed") + } + + if c.HasChanged("counters.requery_every_packets") { + n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery) + f.reQueryEvery.Store(n) + f.l.Info("counters.requery_every_packets has changed") + } + + if c.HasChanged("timers.requery_wait_duration") { + n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait) + f.reQueryWait.Store(int64(n)) + f.l.Info("timers.requery_wait_duration has changed") + } +} + func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() udpStats := udp.NewUDPStatsEmitter(f.writers) + certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + for { select { case <-ctx.Done(): @@ -386,6 +409,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } @@ -393,6 +417,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { func (f *Interface) Close() error { f.closed.Store(true) + for _, u := range f.writers { + err := u.Close() + if err != nil { + f.l.WithError(err).Error("Error while closing udp socket") + } + } + // Release the tun device return f.inside.Close() } diff --git a/iputil/packet.go b/iputil/packet.go new file mode 100644 index 000000000..b18e52447 --- /dev/null +++ b/iputil/packet.go @@ -0,0 +1,238 @@ +package iputil + +import ( + "encoding/binary" + + "golang.org/x/net/ipv4" +) + +const ( + // Need 96 bytes for the largest reject packet: + // - 20 byte ipv4 header + // - 8 byte icmpv4 header + // - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header) + MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8 +) + +func CreateRejectPacket(packet []byte, out []byte) []byte { + if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version { + return nil + } + + switch packet[9] { + case 6: // tcp + return ipv4CreateRejectTCPPacket(packet, out) + default: + return ipv4CreateRejectICMPPacket(packet, out) + } +} + +func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte { + ihl := int(packet[0]&0x0f) << 2 + + if len(packet) < ihl { + // We need at least this many bytes for this to be a valid packet + return nil + } + + // ICMP reply includes original header and first 8 bytes of the packet + packetLen := len(packet) + if packetLen > ihl+8 { + packetLen = ihl + 8 + } + + outLen := ipv4.HeaderLen + 8 + packetLen + if outLen > cap(out) { + return nil + } + + out = out[:outLen] + + ipHdr := out[0:ipv4.HeaderLen] + ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl + ipHdr[1] = 0 // DSCP, ECN + 
binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length + + ipHdr[4] = 0 // id + ipHdr[5] = 0 // . + ipHdr[6] = 0 // flags, fragment offset + ipHdr[7] = 0 // . + ipHdr[8] = 64 // TTL + ipHdr[9] = 1 // protocol (icmp) + ipHdr[10] = 0 // checksum + ipHdr[11] = 0 // . + + // Swap dest / src IPs + copy(ipHdr[12:16], packet[16:20]) + copy(ipHdr[16:20], packet[12:16]) + + // Calculate checksum + binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) + + // ICMP Destination Unreachable + icmpOut := out[ipv4.HeaderLen:] + icmpOut[0] = 3 // type (Destination unreachable) + icmpOut[1] = 3 // code (Port unreachable error) + icmpOut[2] = 0 // checksum + icmpOut[3] = 0 // . + icmpOut[4] = 0 // unused + icmpOut[5] = 0 // . + icmpOut[6] = 0 // . + icmpOut[7] = 0 // . + + // Copy original IP header and first 8 bytes as body + copy(icmpOut[8:], packet[:packetLen]) + + // Calculate checksum + binary.BigEndian.PutUint16(icmpOut[2:], tcpipChecksum(icmpOut, 0)) + + return out +} + +func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte { + const tcpLen = 20 + + ihl := int(packet[0]&0x0f) << 2 + outLen := ipv4.HeaderLen + tcpLen + + if len(packet) < ihl+tcpLen { + // We need at least this many bytes for this to be a valid packet + return nil + } + if outLen > cap(out) { + return nil + } + + out = out[:outLen] + + ipHdr := out[0:ipv4.HeaderLen] + ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl + ipHdr[1] = 0 // DSCP, ECN + binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length + ipHdr[4] = 0 // id + ipHdr[5] = 0 // . + ipHdr[6] = 0 // flags, fragment offset + ipHdr[7] = 0 // . + ipHdr[8] = 64 // TTL + ipHdr[9] = 6 // protocol (tcp) + ipHdr[10] = 0 // checksum + ipHdr[11] = 0 // . + + // Swap dest / src IPs + copy(ipHdr[12:16], packet[16:20]) + copy(ipHdr[16:20], packet[12:16]) + + // Calculate checksum + binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) + + // TCP RST + tcpIn := packet[ihl:] + var ackSeq, seq uint32 + outFlags := byte(0b00000100) // RST + + // Set seq and ackSeq based on how iptables/netfilter does it in Linux: + // - https://github.com/torvalds/linux/blob/v5.19/net/ipv4/netfilter/nf_reject_ipv4.c#L193-L221 + inAck := tcpIn[13]&0b00010000 != 0 + if inAck { + seq = binary.BigEndian.Uint32(tcpIn[8:]) + } else { + inSyn := uint32((tcpIn[13] & 0b00000010) >> 1) + inFin := uint32(tcpIn[13] & 0b00000001) + // seq from the packet + syn + fin + tcp segment length + ackSeq = binary.BigEndian.Uint32(tcpIn[4:]) + inSyn + inFin + uint32(len(tcpIn)) - uint32(tcpIn[12]>>4)<<2 + outFlags |= 0b00010000 // ACK + } + + tcpOut := out[ipv4.HeaderLen:] + // Swap dest / src ports + copy(tcpOut[0:2], tcpIn[2:4]) + copy(tcpOut[2:4], tcpIn[0:2]) + binary.BigEndian.PutUint32(tcpOut[4:], seq) + binary.BigEndian.PutUint32(tcpOut[8:], ackSeq) + tcpOut[12] = (tcpLen >> 2) << 4 // data offset, reserved, NS + tcpOut[13] = outFlags // CWR, ECE, URG, ACK, PSH, RST, SYN, FIN + tcpOut[14] = 0 // window size + tcpOut[15] = 0 // . + tcpOut[16] = 0 // checksum + tcpOut[17] = 0 // . + tcpOut[18] = 0 // URG Pointer + tcpOut[19] = 0 // . 
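For context on how these builders are consumed, a hedged usage sketch modeled on `rejectInside`/`rejectOutside` in `inside.go` above; the tun writer is a stand-in and error handling is trimmed:

```go
package example

import (
	"io"

	"github.com/slackhq/nebula/iputil"
)

// writeReject synthesizes an ICMP port-unreachable or TCP RST for a dropped
// IPv4 packet and writes it back toward the original sender.
func writeReject(dropped []byte, tun io.Writer) error {
	// MaxRejectPacketSize covers the worst case: a 20 byte reply IPv4 header,
	// an 8 byte ICMP header, and up to 68 bytes of the quoted original packet.
	out := make([]byte, iputil.MaxRejectPacketSize)

	out = iputil.CreateRejectPacket(dropped, out)
	if len(out) == 0 {
		// Not IPv4, or too short to build a reply: drop silently.
		return nil
	}

	_, err := tun.Write(out)
	return err
}
```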
+ + // Calculate checksum + csum := ipv4PseudoheaderChecksum(ipHdr[12:16], ipHdr[16:20], 6, tcpLen) + binary.BigEndian.PutUint16(tcpOut[16:], tcpipChecksum(tcpOut, csum)) + + return out +} + +func CreateICMPEchoResponse(packet, out []byte) []byte { + // Return early if this is not a simple ICMP Echo Request + //TODO: make constants out of these + if !(len(packet) >= 28 && len(packet) <= 9001 && packet[0] == 0x45 && packet[9] == 0x01 && packet[20] == 0x08) { + return nil + } + + // We don't support fragmented packets + if packet[7] != 0 || (packet[6]&0x2F != 0) { + return nil + } + + out = out[:len(packet)] + + copy(out, packet) + + // Swap dest / src IPs and recalculate checksum + ipv4 := out[0:20] + copy(ipv4[12:16], packet[16:20]) + copy(ipv4[16:20], packet[12:16]) + ipv4[10] = 0 + ipv4[11] = 0 + binary.BigEndian.PutUint16(ipv4[10:], tcpipChecksum(ipv4, 0)) + + // Change type to ICMP Echo Reply and recalculate checksum + icmp := out[20:] + icmp[0] = 0 + icmp[2] = 0 + icmp[3] = 0 + binary.BigEndian.PutUint16(icmp[2:], tcpipChecksum(icmp, 0)) + + return out +} + +// calculates the TCP/IP checksum defined in rfc1071. The passed-in +// csum is any initial checksum data that's already been computed. +// +// based on: +// - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L50-L70 +func tcpipChecksum(data []byte, csum uint32) uint16 { + // to handle odd lengths, we loop to length - 1, incrementing by 2, then + // handle the last byte specifically by checking against the original + // length. + length := len(data) - 1 + for i := 0; i < length; i += 2 { + // For our test packet, doing this manually is about 25% faster + // (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16. + csum += uint32(data[i]) << 8 + csum += uint32(data[i+1]) + } + if len(data)%2 == 1 { + csum += uint32(data[length]) << 8 + } + for csum > 0xffff { + csum = (csum >> 16) + (csum & 0xffff) + } + return ^uint16(csum) +} + +// based on: +// - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L26-L35 +func ipv4PseudoheaderChecksum(src, dst []byte, proto, length uint32) (csum uint32) { + csum += (uint32(src[0]) + uint32(src[2])) << 8 + csum += uint32(src[1]) + uint32(src[3]) + csum += (uint32(dst[0]) + uint32(dst[2])) << 8 + csum += uint32(dst[1]) + uint32(dst[3]) + csum += proto + csum += length & 0xffff + csum += length >> 16 + return csum +} diff --git a/iputil/packet_test.go b/iputil/packet_test.go new file mode 100644 index 000000000..e1d0d95d8 --- /dev/null +++ b/iputil/packet_test.go @@ -0,0 +1,73 @@ +package iputil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/ipv4" +) + +func Test_CreateRejectPacket(t *testing.T) { + h := ipv4.Header{ + Len: 20, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + } + + b, err := h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + + expectedLen := ipv4.HeaderLen + 8 + h.Len + 4 + out := make([]byte, expectedLen) + rejectPacket := CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // ICMP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...) 
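The length assertions in this test cover buffer sizing; a hedged sketch of one extra property that could be asserted is that the RFC 1071 sum over a finished header folds to zero once it contains its own checksum (`verifyRFC1071` is illustrative and not part of this change; the TCP checksum would additionally need the pseudo-header folded in):

```go
// verifyRFC1071 reports whether a block that already carries its checksum sums
// to the all-ones value, i.e. the inverted fold is zero. This holds for the
// IPv4 header and the ICMP message produced by CreateRejectPacket.
func verifyRFC1071(data []byte) bool {
	var sum uint32
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if len(data)%2 == 1 {
		sum += uint32(data[len(data)-1]) << 8
	}
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return ^uint16(sum) == 0
}
```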
+ + expectedLen = MaxRejectPacketSize + out = make([]byte, MaxRejectPacketSize) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // TCP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 6, // TCP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + b = append(b, make([]byte, 16)...) + + expectedLen = ipv4.HeaderLen + 20 + out = make([]byte, expectedLen) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) +} diff --git a/lighthouse.go b/lighthouse.go index 60e1f29e5..aa54c4bc5 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -6,12 +6,14 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -32,11 +34,12 @@ type netIpAndPort struct { type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps + ctx context.Context amLighthouse bool myVpnIp iputil.VpnIp myVpnZeros iputil.VpnIp myVpnNet *net.IPNet - punchConn *udp.Conn + punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses @@ -61,17 +64,20 @@ type LightHouse struct { staticList atomic.Pointer[map[iputil.VpnIp]struct{}] lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] - interval atomic.Int64 - updateCancel context.CancelFunc - updateParentCtx context.Context - updateUdp udp.EncWriter - nebulaPort uint32 // 32 bits because protobuf does not have a uint16 + interval atomic.Int64 + updateCancel context.CancelFunc + ifce EncWriter + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netIpAndPort] // IP's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]iputil.VpnIp] + queryChan chan iputil.VpnIp + + calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + metrics *MessageMetrics metricHolepunchTx metrics.Counter l *logrus.Logger @@ -79,7 +85,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -97,6 +103,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, ones, _ := myVpnNet.Mask.Size() h := LightHouse{ + ctx: ctx, amLighthouse: amLighthouse, myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), myVpnZeros: iputil.VpnIp(32 - ones), @@ -105,6 +112,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), l: 
l, } lighthouses := make(map[iputil.VpnIp]struct{}) @@ -127,13 +135,15 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, c.RegisterReloadCallback(func(c *config.C) { err := h.reload(c, false) switch v := err.(type) { - case util.ContextualError: + case *util.ContextualError: v.Log(l) case error: l.WithError(err).Error("failed to reload lighthouse") } }) + h.startQueryWorker() + return &h, nil } @@ -161,6 +171,10 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { return *lh.relaysForMe.Load() } +func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { + return lh.calculatedRemotes.Load() +} + func (lh *LightHouse) GetUpdateInterval() int64 { return lh.interval.Load() } @@ -207,7 +221,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.updateCancel() } - lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp) + lh.StartUpdateWorker() } } @@ -237,8 +251,33 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + if initial || c.HasChanged("lighthouse.calculated_remotes") { + cr, err := NewCalculatedRemotesFromConfig(c, "lighthouse.calculated_remotes") + if err != nil { + return util.NewContextualError("Invalid lighthouse.calculated_remotes", nil, err) + } + + lh.calculatedRemotes.Store(cr) + if !initial { + //TODO: a diff will be annoyingly difficult + lh.l.Info("lighthouse.calculated_remotes has changed") + } + } + //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config - if initial || c.HasChanged("static_host_map") { + if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { + // Clean up. Entries still in the static_host_map will be re-built. + // Entries no longer present must have their (possible) background DNS goroutines stopped. + if existingStaticList := lh.staticList.Load(); existingStaticList != nil { + lh.RLock() + for staticVpnIp := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + am.hr.Cancel() + } + } + lh.RUnlock() + } + // Build a new list based on current config. staticList := make(map[iputil.VpnIp]struct{}) err := lh.loadStaticMap(c, lh.myVpnNet, staticList) if err != nil { @@ -248,9 +287,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.staticList.Store(&staticList) if !initial { //TODO: we should remove any remote list entries for static hosts that were removed/modified? 
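The branch above re-reads the `static_map` options whenever any of the relevant keys change; a minimal sketch of that reload-callback pattern, using only `config.C` methods that already appear in this diff (the `applySettings` hook is hypothetical):

```go
package example

import (
	"time"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
)

// registerStaticMapReload re-reads the static_map options only when one of the
// relevant keys changed, falling back to the defaults used elsewhere in this
// change: cadence 30s, lookup_timeout 250ms, network ip4.
func registerStaticMapReload(l *logrus.Logger, c *config.C, applySettings func(cadence, timeout time.Duration, network string)) {
	c.RegisterReloadCallback(func(c *config.C) {
		if !c.HasChanged("static_host_map") &&
			!c.HasChanged("static_map.cadence") &&
			!c.HasChanged("static_map.network") &&
			!c.HasChanged("static_map.lookup_timeout") {
			return
		}

		cadence, err := time.ParseDuration(c.GetString("static_map.cadence", "30s"))
		if err != nil {
			l.WithError(err).Error("Invalid static_map.cadence")
			return
		}

		timeout, err := time.ParseDuration(c.GetString("static_map.lookup_timeout", "250ms"))
		if err != nil {
			l.WithError(err).Error("Invalid static_map.lookup_timeout")
			return
		}

		applySettings(cadence, timeout, c.GetString("static_map.network", "ip4"))
	})
}
```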
- lh.l.Info("static_host_map has changed") + if c.HasChanged("static_host_map") { + lh.l.Info("static_host_map has changed") + } + if c.HasChanged("static_map.cadence") { + lh.l.Info("static_map.cadence has changed") + } + if c.HasChanged("static_map.network") { + lh.l.Info("static_map.network has changed") + } + if c.HasChanged("static_map.lookup_timeout") { + lh.l.Info("static_map.lookup_timeout has changed") + } } - } if initial || c.HasChanged("lighthouse.hosts") { @@ -279,7 +328,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { case false: relaysForMe := []iputil.VpnIp{} for _, v := range c.GetStringSlice("relay.relays", nil) { - lh.l.WithField("RelayIP", v).Info("Read relay from config") + lh.l.WithField("relay", v).Info("Read relay from config") configRIP := net.ParseIP(v) if configRIP != nil { @@ -324,7 +373,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma return nil } +func getStaticMapCadence(c *config.C) (time.Duration, error) { + cadence := c.GetString("static_map.cadence", "30s") + d, err := time.ParseDuration(cadence) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) { + lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms") + d, err := time.ParseDuration(lookupTimeout) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapNetwork(c *config.C) (string, error) { + network := c.GetString("static_map.network", "ip4") + if network != "ip" && network != "ip4" && network != "ip6" { + return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6") + } + return network, nil +} + func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { + d, err := getStaticMapCadence(c) + if err != nil { + return err + } + + network, err := getStaticMapNetwork(c) + if err != nil { + return err + } + + lookup_timeout, err := getStaticMapLookupTimeout(c) + if err != nil { + return err + } + shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) i := 0 @@ -340,21 +430,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) - if ok { - for _, v := range vals { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) - } + if !ok { + vals = []interface{}{v} + } + remoteAddrs := []string{} + for _, v := range vals { + remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) + } - } else { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) + err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + if err != nil { + return err } i++ } @@ -362,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip, f) + lh.QueryServer(ip) } lh.RLock() if v, ok := lh.addrMap[ip]; ok { @@ -375,30 +461,14 @@ func (lh 
*LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { return nil } -// This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { - if lh.amLighthouse { - return - } - - if lh.IsLighthouseIP(ip) { +// QueryServer is asynchronous so no reply should be expected +func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { + // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseIP(ip) { return } - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return - } - - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for n := range lighthouses { - f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) - } + lh.queryChan <- ip } func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { @@ -462,42 +532,121 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() defer am.Unlock() + ctx := lh.ctx lh.Unlock() - if ipv4 := toAddr.IP.To4(); ipv4 != nil { - to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV4(vpnIp, to) { - return - } - am.unlockedPrependV4(lh.myVpnIp, to) + hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { + // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // in its resolution for hostnames. + am.Lock() + defer am.Unlock() + am.shouldRebuild = true + }) + if err != nil { + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + } + am.unlockedSetHostnamesResults(hr) - } else { - to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV6(vpnIp, to) { - return + for _, addrPort := range hr.GetIPs() { + + switch { + case addrPort.Addr().Is4(): + to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV4(vpnIp, to) { + continue + } + am.unlockedPrependV4(lh.myVpnIp, to) + case addrPort.Addr().Is6(): + to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV6(vpnIp, to) { + continue + } + am.unlockedPrependV6(lh.myVpnIp, to) } - am.unlockedPrependV6(lh.myVpnIp, to) } // Mark it as static in the caller provided map staticList[vpnIp] = struct{}{} + return nil +} + +// addCalculatedRemotes adds any calculated remotes based on the +// lighthouse.calculated_remotes configuration. 
It returns true if any +// calculated remotes were added +func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { + tree := lh.getCalculatedRemotes() + if tree == nil { + return false + } + ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + if !ok { + return false + } + + var calculated []*Ip4AndPort + for _, cr := range calculatedRemotes { + c := cr.Apply(vpnIp) + if c != nil { + calculated = append(calculated, c) + } + } + + lh.Lock() + am := lh.unlockedGetRemoteList(vpnIp) + am.Lock() + defer am.Unlock() + lh.Unlock() + + am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + + return len(calculated) > 0 } // unlockedGetRemoteList assumes you have the lh lock func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { - am = NewRemoteList() + am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) lh.addrMap[vpnIp] = am } return am } +func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { + switch { + case to.Is4(): + ipBytes := to.As4() + ip := iputil.Ip2VpnIp(ipBytes[:]) + allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { + return false + } + case to.Is6(): + ipBytes := to.As16() + + hi := binary.BigEndian.Uint64(ipBytes[:8]) + lo := binary.BigEndian.Uint64(ipBytes[8:]) + allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + + // We don't check our vpn network here because nebula does not support ipv6 on the inside + if !allow { + return false + } + } + return true +} + // unlockedShouldAddV4 checks if to is allowed by our allow list func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) @@ -556,6 +705,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { return &ipp } +func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { + v4Addr := ip.As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { return &Ip6AndPort{ Hi: binary.BigEndian.Uint64(ip[:8]), @@ -564,6 +721,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { } } +func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { + ip6Addr := ip.As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(ip6Addr[:8]), + Lo: binary.BigEndian.Uint64(ip6Addr[8:]), + Port: uint32(port), + } +} func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { ip := ipp.Ip return udp.NewAddr( @@ -576,33 +741,73 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { - lh.updateParentCtx = ctx - lh.updateUdp = f +func (lh *LightHouse) startQueryWorker() { + if lh.amLighthouse { + return + } + + go func() { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for { + select { + case <-lh.ctx.Done(): + return + case ip := <-lh.queryChan: + lh.innerQueryServer(ip, nb, out) + } + } + }() +} + +func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { + if lh.IsLighthouseIP(ip) { + return + } + + 
// Send a query to the lighthouses and hope for the best next time + query, err := NewLhQueryByInt(ip).Marshal() + if err != nil { + lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") + return + } + + lighthouses := lh.GetLighthouses() + lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) + + for n := range lighthouses { + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + } +} + +func (lh *LightHouse) StartUpdateWorker() { interval := lh.GetUpdateInterval() if lh.amLighthouse || interval == 0 { return } clockSource := time.NewTicker(time.Second * time.Duration(interval)) - updateCtx, cancel := context.WithCancel(ctx) + updateCtx, cancel := context.WithCancel(lh.ctx) lh.updateCancel = cancel - defer clockSource.Stop() - for { - lh.SendUpdate(f) + go func() { + defer clockSource.Stop() - select { - case <-updateCtx.Done(): - return - case <-clockSource.C: - continue + for { + lh.SendUpdate() + + select { + case <-updateCtx.Done(): + return + case <-clockSource.C: + continue + } } - } + }() } -func (lh *LightHouse) SendUpdate(f udp.EncWriter) { +func (lh *LightHouse) SendUpdate() { var v4 []*Ip4AndPort var v6 []*Ip6AndPort @@ -655,7 +860,7 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) { } for vpnIp := range lighthouses { - f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) } } @@ -707,7 +912,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { +func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { + return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + lhh.HandleRequest(rAddr, vpnIp, p, f) + } +} + +func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -734,15 +945,18 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, lhh.handleHostQueryReply(n, vpnIp) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp) + lhh.handleHostUpdateNotification(n, vpnIp, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: lhh.handleHostPunchNotification(n, vpnIp, w) + + case NebulaMeta_HostUpdateNotificationAck: + // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -847,7 +1061,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.V } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -873,9 +1087,22 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) am.Unlock() + + n = lhh.resetMeta() + 
n.Type = NebulaMeta_HostUpdateNotificationAck + n.Details.VpnIp = uint32(vpnIp) + ln, err := n.MarshalTo(lhh.pb) + + if err != nil { + lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + return + } + + lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) + w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } @@ -912,7 +1139,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i if lhh.lh.punchy.GetRespond() { queryVpnIp := iputil.VpnIp(n.Details.VpnIp) go func() { - time.Sleep(time.Second * 5) + time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", queryVpnIp) } diff --git a/lighthouse_test.go b/lighthouse_test.go index e5a169281..66427e339 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "fmt" "net" "testing" @@ -11,6 +12,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" ) //TODO: Add a test to ensure udpAddr is copied and not reused @@ -53,30 +55,59 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } +func TestReloadLighthouseInterval(t *testing.T) { + l := test.NewLogger() + _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + lh1 := "10.128.0.2" + + c := config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "hosts": []interface{}{lh1}, + "interval": "1s", + } + + c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + assert.NoError(t, err) + lh.ifce = &mockEncWriter{} + + // The first one routine is kicked off by main.go currently, lets make sure that one dies + c.ReloadConfigString("lighthouse:\n interval: 5") + assert.Equal(t, int64(5), lh.interval.Load()) + + // Subsequent calls are killed off by the LightHouse.Reload function + c.ReloadConfigString("lighthouse:\n interval: 10") + assert.Equal(t, int64(10), lh.interval.Load()) + + // If this completes then nothing is stealing our reload routine + c.ReloadConfigString("lighthouse:\n interval: 11") + assert.Equal(t, int64(11), lh.interval.Load()) +} + func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") c := 
config.NewC(l) - lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) if !assert.NoError(b, err) { b.Fatal() } hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList() + lh.addrMap[3] = NewRemoteList(nil) lh.addrMap[3].unlockedSetV4( 3, 3, @@ -89,7 +120,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { rAddr := udp.NewAddrFromString("1.2.2.3:12345") rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList() + lh.addrMap[2] = NewRemoteList(nil) lh.addrMap[2].unlockedSetV4( 3, 3, @@ -162,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() @@ -238,11 +269,20 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + assert.NoError(t, err) + + nc := map[interface{}]interface{}{ + "static_host_map": map[interface{}]interface{}{ + "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + }, + } + rc, err := yaml.Marshal(nc) assert.NoError(t, err) + c.ReloadConfigString(string(rc)) - c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} - lh.reload(c, false) + err = lh.reload(c, false) + assert.NoError(t, err) } func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { @@ -372,11 +412,28 @@ type testEncWriter struct { metaFilter *NebulaMeta_MessageType } -func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { } +func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { + msg := &NebulaMeta{} + err := msg.Unmarshal(p) + if tw.metaFilter == nil || msg.Type == *tw.metaFilter { + tw.lastReply = testLhReply{ + nebType: t, + nebSubType: st, + vpnIp: hostinfo.vpnIp, + msg: msg, + } + } + + if err != nil { + panic(err) + } +} + func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) diff --git a/main.go b/main.go index 99fe72cc0..7a0a0cff3 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package nebula import ( "context" "encoding/binary" - "errors" "fmt" "net" "time" @@ -19,7 +18,7 @@ import ( type m map[string]interface{} -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, 
tunFd *int) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { @@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg err := configLogger(l, c) if err != nil { - return nil, util.NewContextualError("Failed to configure the logger", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) } c.RegisterReloadCallback(func(c *config.C) { @@ -56,36 +55,31 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } }) - caPool, err := loadCAFromConfig(l, c) + pki, err := NewPKIFromConfig(l, c) if err != nil { - //The errors coming out of loadCA are already nicely formatted - return nil, util.NewContextualError("Failed to load ca from config", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") - cs, err := NewCertStateFromConfig(c) + certificate := pki.GetCertState().Certificate + fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { - //The errors coming out of NewCertStateFromConfig are already nicely formatted - return nil, util.NewContextualError("Failed to load certificate from config", nil, err) + return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - - fw, err := NewFirewallFromConfig(l, cs.certificate, c) - if err != nil { - return nil, util.NewContextualError("Error while loading firewall rules", nil, err) - } - l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") + l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") // TODO: make sure mask is 4 bytes - tunCidr := cs.certificate.Details.Ips[0] + tunCidr := certificate.Details.Ips[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + if err != nil { + return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) + } wireSSHReload(l, ssh, c) var sshStart func() if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, util.NewContextualError("Error while configuring the sshd", nil, err) + return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err) } } @@ -134,9 +128,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { c.CatchHUP(ctx) - tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) + if deviceFactory == nil { + deviceFactory = overlay.NewDeviceFromConfig + } + + tun, err = deviceFactory(c, l, tunCidr, routines) if err != nil { - return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } defer func() { @@ -147,83 +145,49 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // set up our UDP listener - udpConns := make([]*udp.Conn, routines) + udpConns := make([]udp.Conn, routines) port := c.GetInt("listen.port", 0) if !configTest { - for i := 0; i < routines; i++ { - udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), 
port, routines > 1, c.GetInt("listen.batch", 64)) + rawListenHost := c.GetString("listen.host", "0.0.0.0") + var listenHost *net.IPAddr + if rawListenHost == "[::]" { + // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. + listenHost = &net.IPAddr{IP: net.IPv6zero} + + } else { + listenHost, err = net.ResolveIPAddr("ip", rawListenHost) if err != nil { - return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } - udpServer.ReloadConfig(c) - udpConns[i] = udpServer } - } - // Set up my internal host map - var preferredRanges []*net.IPNet - rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) - // First, check if 'preferred_ranges' is set and fallback to 'local_range' - if len(rawPreferredRanges) > 0 { - for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + for i := 0; i < routines; i++ { + l.Infof("listening %q %d", listenHost.IP, port) + udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { - return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } - preferredRanges = append(preferredRanges, preferredRange) - } - } - - // local_range was superseded by preferred_ranges. If it is still present, - // merge the local_range setting into preferred_ranges. We will probably - // deprecate local_range and remove in the future. - rawLocalRange := c.GetString("local_range", "") - if rawLocalRange != "" { - _, localRange, err := net.ParseCIDR(rawLocalRange) - if err != nil { - return nil, util.NewContextualError("Failed to parse local_range", nil, err) - } + udpServer.ReloadConfig(c) + udpConns[i] = udpServer - // Check if the entry for local_range was already specified in - // preferred_ranges. Don't put it into the slice twice if so. - var found bool - for _, r := range preferredRanges { - if r.String() == localRange.String() { - found = true - break + // If port is dynamic, discover it before the next pass through the for loop + // This way all routines will use the same port correctly + if port == 0 { + uPort, err := udpServer.LocalAddr() + if err != nil { + return nil, util.NewContextualError("Failed to get listening port", nil, err) + } + port = int(uPort.Port) } } - if !found { - preferredRanges = append(preferredRanges, localRange) - } } - hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) - hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) - - l. - WithField("network", hostMap.vpnCIDR.String()). - WithField("preferredRanges", hostMap.preferredRanges). 
- Info("Main HostMap created") - - /* - config.SetDefault("promoter.interval", 10) - go hostMap.Promoter(config.GetInt("promoter.interval")) - */ - + hostMap := NewHostMapFromConfig(l, tunCidr, c) punchy := NewPunchyFromConfig(l, c) - if punchy.GetPunch() && !configTest { - l.Info("UDP hole punching enabled") - go hostMap.Punchy(ctx, udpConns[0]) - } - - lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) - switch { - case errors.As(err, &util.ContextualError{}): - return nil, err - case err != nil: - return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + if err != nil { + return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } var messageMetrics *MessageMetrics @@ -244,13 +208,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg messageMetrics: messageMetrics, } - handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) + handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - //TODO: These will be reused for psk - //handshakeMACKey := config.GetString("handshake_mac.key", "") - //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) - serveDns := false if c.GetBool("lighthouse.serve_dns", false) { if c.GetBool("lighthouse.am_lighthouse", false) { @@ -262,26 +222,29 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg checkInterval := c.GetInt("timers.connection_alive_interval", 5) pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10) + ifConfig := &InterfaceConfig{ HostMap: hostMap, Inside: tun, Outside: udpConns[0], - certState: cs, + pki: pki, Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, lightHouse: lightHouse, - checkInterval: checkInterval, - pendingDeletionInterval: pendingDeletionInterval, + checkInterval: time.Second * time.Duration(checkInterval), + pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval), + tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), + reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), + reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), DropMulticast: c.GetBool("tun.drop_multicast", false), routines: routines, MessageMetrics: messageMetrics, version: buildVersion, - caPool: caPool, - disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), + punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, @@ -306,21 +269,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // TODO: Better way to attach these, probably want a new interface in InterfaceConfig // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns + lightHouse.ifce = ifce ifce.RegisterConfigChangeCallbacks(c) - + ifce.reloadDisconnectInvalid(c) ifce.reloadSendRecvError(c) - go handshakeManager.Run(ctx, ifce) - go lightHouse.LhUpdateWorker(ctx, ifce) + handshakeManager.f = ifce + go handshakeManager.Run(ctx) } // TODO - stats third-party modules start uncancellable goroutines. 
Update those libs to accept // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) - if err != nil { - return nil, util.NewContextualError("Failed to start stats emitter", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } if configTest { @@ -330,7 +293,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) - attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) + attachCommands(l, c, ssh, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() @@ -339,5 +302,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg dnsStart = dnsMain(l, hostMap, c) } - return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil + return &Control{ + ifce, + l, + ctx, + cancel, + sshStart, + statsStart, + dnsStart, + lightHouse.StartUpdateWorker, + }, nil } diff --git a/message_metrics.go b/message_metrics.go index b229cdf07..94bb02fe3 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -84,6 +84,7 @@ func newLighthouseMetrics() *MessageMetrics { NebulaMeta_HostQueryReply, NebulaMeta_HostUpdateNotification, NebulaMeta_HostPunchNotification, + NebulaMeta_HostUpdateNotificationAck, } for _, i := range used { h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)} diff --git a/nebula.pb.go b/nebula.pb.go index 649b7cb5e..b3c723a46 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -25,42 +25,45 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type NebulaMeta_MessageType int32 const ( - NebulaMeta_None NebulaMeta_MessageType = 0 - NebulaMeta_HostQuery NebulaMeta_MessageType = 1 - NebulaMeta_HostQueryReply NebulaMeta_MessageType = 2 - NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3 - NebulaMeta_HostMovedNotification NebulaMeta_MessageType = 4 - NebulaMeta_HostPunchNotification NebulaMeta_MessageType = 5 - NebulaMeta_HostWhoami NebulaMeta_MessageType = 6 - NebulaMeta_HostWhoamiReply NebulaMeta_MessageType = 7 - NebulaMeta_PathCheck NebulaMeta_MessageType = 8 - NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9 + NebulaMeta_None NebulaMeta_MessageType = 0 + NebulaMeta_HostQuery NebulaMeta_MessageType = 1 + NebulaMeta_HostQueryReply NebulaMeta_MessageType = 2 + NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3 + NebulaMeta_HostMovedNotification NebulaMeta_MessageType = 4 + NebulaMeta_HostPunchNotification NebulaMeta_MessageType = 5 + NebulaMeta_HostWhoami NebulaMeta_MessageType = 6 + NebulaMeta_HostWhoamiReply NebulaMeta_MessageType = 7 + NebulaMeta_PathCheck NebulaMeta_MessageType = 8 + NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9 + NebulaMeta_HostUpdateNotificationAck NebulaMeta_MessageType = 10 ) var NebulaMeta_MessageType_name = map[int32]string{ - 0: "None", - 1: "HostQuery", - 2: "HostQueryReply", - 3: "HostUpdateNotification", - 4: "HostMovedNotification", - 5: "HostPunchNotification", - 6: "HostWhoami", - 7: "HostWhoamiReply", - 8: "PathCheck", - 9: "PathCheckReply", + 0: "None", + 1: "HostQuery", + 2: "HostQueryReply", + 3: "HostUpdateNotification", + 4: "HostMovedNotification", + 5: "HostPunchNotification", + 6: "HostWhoami", + 7: "HostWhoamiReply", + 8: "PathCheck", + 9: "PathCheckReply", + 10: 
"HostUpdateNotificationAck", } var NebulaMeta_MessageType_value = map[string]int32{ - "None": 0, - "HostQuery": 1, - "HostQueryReply": 2, - "HostUpdateNotification": 3, - "HostMovedNotification": 4, - "HostPunchNotification": 5, - "HostWhoami": 6, - "HostWhoamiReply": 7, - "PathCheck": 8, - "PathCheckReply": 9, + "None": 0, + "HostQuery": 1, + "HostQueryReply": 2, + "HostUpdateNotification": 3, + "HostMovedNotification": 4, + "HostPunchNotification": 5, + "HostWhoami": 6, + "HostWhoamiReply": 7, + "PathCheck": 8, + "PathCheckReply": 9, + "HostUpdateNotificationAck": 10, } func (x NebulaMeta_MessageType) String() string { @@ -637,51 +640,52 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 696 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0xcd, 0x6e, 0xd3, 0x4a, - 0x14, 0x8e, 0x1d, 0xe7, 0xef, 0xa4, 0x49, 0x7d, 0x4f, 0xef, 0xcd, 0x4d, 0xaf, 0xae, 0xac, 0xe0, - 0x05, 0xca, 0x2a, 0xad, 0xd2, 0x52, 0xb1, 0x04, 0x82, 0x50, 0x52, 0xb5, 0x55, 0x18, 0x15, 0x90, - 0xd8, 0xa0, 0x69, 0x32, 0xd4, 0x56, 0x12, 0x8f, 0x6b, 0x4f, 0x50, 0xf3, 0x16, 0x3c, 0x4c, 0x1f, - 0x82, 0x05, 0x12, 0x5d, 0xb0, 0x60, 0x89, 0xda, 0x17, 0x41, 0x33, 0x76, 0x6c, 0x27, 0x0d, 0xec, - 0xce, 0xcf, 0xf7, 0xcd, 0x7c, 0xe7, 0x9b, 0x63, 0xc3, 0x96, 0xc7, 0x2e, 0xe6, 0x53, 0xda, 0xf1, - 0x03, 0x2e, 0x38, 0x16, 0xa3, 0xcc, 0xfe, 0xaa, 0x03, 0x9c, 0xa9, 0xf0, 0x94, 0x09, 0x8a, 0x5d, - 0x30, 0xce, 0x17, 0x3e, 0x6b, 0x6a, 0x2d, 0xad, 0x5d, 0xef, 0x5a, 0x9d, 0x98, 0x93, 0x22, 0x3a, - 0xa7, 0x2c, 0x0c, 0xe9, 0x25, 0x93, 0x28, 0xa2, 0xb0, 0x78, 0x00, 0xa5, 0x97, 0x4c, 0x50, 0x77, - 0x1a, 0x36, 0xf5, 0x96, 0xd6, 0xae, 0x76, 0x77, 0x1f, 0xd2, 0x62, 0x00, 0x59, 0x22, 0xed, 0xef, - 0x1a, 0x54, 0x33, 0x47, 0x61, 0x19, 0x8c, 0x33, 0xee, 0x31, 0x33, 0x87, 0x35, 0xa8, 0xf4, 0x79, - 0x28, 0x5e, 0xcf, 0x59, 0xb0, 0x30, 0x35, 0x44, 0xa8, 0x27, 0x29, 0x61, 0xfe, 0x74, 0x61, 0xea, - 0xf8, 0x1f, 0x34, 0x64, 0xed, 0x8d, 0x3f, 0xa6, 0x82, 0x9d, 0x71, 0xe1, 0x7e, 0x74, 0x47, 0x54, - 0xb8, 0xdc, 0x33, 0xf3, 0xb8, 0x0b, 0xff, 0xc8, 0xde, 0x29, 0xff, 0xc4, 0xc6, 0x2b, 0x2d, 0x63, - 0xd9, 0x1a, 0xce, 0xbd, 0x91, 0xb3, 0xd2, 0x2a, 0x60, 0x1d, 0x40, 0xb6, 0xde, 0x39, 0x9c, 0xce, - 0x5c, 0xb3, 0x88, 0x3b, 0xb0, 0x9d, 0xe6, 0xd1, 0xb5, 0x25, 0xa9, 0x6c, 0x48, 0x85, 0xd3, 0x73, - 0xd8, 0x68, 0x62, 0x96, 0xa5, 0xb2, 0x24, 0x8d, 0x20, 0x15, 0xfb, 0x9b, 0x06, 0x7f, 0x3d, 0x98, - 0x1a, 0xff, 0x86, 0xc2, 0x5b, 0xdf, 0x1b, 0xf8, 0xca, 0xd6, 0x1a, 0x89, 0x12, 0x3c, 0x84, 0xea, - 0xc0, 0x3f, 0x7c, 0xee, 0x8d, 0x87, 0x3c, 0x10, 0xd2, 0xbb, 0x7c, 0xbb, 0xda, 0xc5, 0xa5, 0x77, - 0x69, 0x8b, 0x64, 0x61, 0x11, 0xeb, 0x28, 0x61, 0x19, 0xeb, 0xac, 0xa3, 0x0c, 0x2b, 0x81, 0xa1, - 0x05, 0x40, 0xd8, 0x94, 0x2e, 0x22, 0x19, 0x85, 0x56, 0xbe, 0x5d, 0x23, 0x99, 0x0a, 0x36, 0xa1, - 0x34, 0xe2, 0x73, 0x4f, 0xb0, 0xa0, 0x99, 0x57, 0x1a, 0x97, 0xa9, 0xbd, 0x0f, 0x90, 0x5e, 0x8f, - 0x75, 0xd0, 0x93, 0x31, 0xf4, 0x81, 0x8f, 0x08, 0x86, 0xac, 0xab, 0x87, 0xaf, 0x11, 0x15, 0xdb, - 0xcf, 0x24, 0xe3, 0x28, 0xc3, 0xe8, 0xbb, 0x8a, 0x61, 0x10, 0xbd, 0xef, 0xca, 0xfc, 0x84, 0x2b, - 0xbc, 0x41, 0xf4, 0x13, 0x9e, 0x9c, 0x90, 0xcf, 0x9c, 0x70, 0xbd, 0xdc, 0xc9, 0xa1, 0xeb, 0x5d, - 0xfe, 0x79, 0x27, 0x25, 0x62, 0xc3, 0x4e, 0x22, 0x18, 0xe7, 0xee, 0x8c, 0xc5, 0xf7, 0xa8, 0xd8, - 0xb6, 0x1f, 0x6c, 0x9c, 0x24, 0x9b, 0x39, 0xac, 0x40, 0x21, 0x7a, 0x3f, 0xcd, 0xfe, 0x00, 0xdb, - 0xd1, 0xb9, 0x7d, 0xea, 0x8d, 0x43, 0x87, 0x4e, 0x18, 0x3e, 
0x4d, 0xd7, 0x5b, 0x53, 0xeb, 0xbd, - 0xa6, 0x20, 0x41, 0xae, 0xef, 0xb8, 0x14, 0xd1, 0x9f, 0xd1, 0x91, 0x12, 0xb1, 0x45, 0x54, 0x6c, - 0xdf, 0x68, 0xd0, 0xd8, 0xcc, 0x93, 0xf0, 0x1e, 0x0b, 0x84, 0xba, 0x65, 0x8b, 0xa8, 0x18, 0x1f, - 0x43, 0x7d, 0xe0, 0xb9, 0xc2, 0xa5, 0x82, 0x07, 0x03, 0x6f, 0xcc, 0xae, 0x63, 0xa7, 0xd7, 0xaa, - 0x12, 0x47, 0x58, 0xe8, 0x73, 0x6f, 0xcc, 0x62, 0x5c, 0xe4, 0xe7, 0x5a, 0x15, 0x1b, 0x50, 0xec, - 0x71, 0x3e, 0x71, 0x59, 0xd3, 0x50, 0xce, 0xc4, 0x59, 0xe2, 0x57, 0x21, 0xf5, 0xeb, 0xd8, 0x28, - 0x17, 0xcd, 0xd2, 0xb1, 0x51, 0x2e, 0x99, 0x65, 0xfb, 0x46, 0x87, 0x5a, 0x24, 0xbb, 0xc7, 0x3d, - 0x11, 0xf0, 0x29, 0x3e, 0x59, 0x79, 0x95, 0x47, 0xab, 0x9e, 0xc4, 0xa0, 0x0d, 0x0f, 0xb3, 0x0f, - 0x3b, 0x89, 0x74, 0xb5, 0x7f, 0xd9, 0xa9, 0x36, 0xb5, 0x24, 0x23, 0x19, 0x22, 0xc3, 0x88, 0xe6, - 0xdb, 0xd4, 0xc2, 0xff, 0xa1, 0xa2, 0xb2, 0x73, 0x3e, 0xf0, 0xd5, 0x9c, 0x35, 0x92, 0x16, 0xb0, - 0x05, 0x55, 0x95, 0xbc, 0x0a, 0xf8, 0x4c, 0x7d, 0x0b, 0xb2, 0x9f, 0x2d, 0xd9, 0xfd, 0xdf, 0xfd, - 0x9a, 0x1a, 0x80, 0xbd, 0x80, 0x51, 0xc1, 0x14, 0x9a, 0xb0, 0xab, 0x39, 0x0b, 0x85, 0xa9, 0xe1, - 0xbf, 0xb0, 0xb3, 0x52, 0x97, 0x92, 0x42, 0x66, 0xea, 0x2f, 0x0e, 0xbe, 0xdc, 0x59, 0xda, 0xed, - 0x9d, 0xa5, 0xfd, 0xbc, 0xb3, 0xb4, 0xcf, 0xf7, 0x56, 0xee, 0xf6, 0xde, 0xca, 0xfd, 0xb8, 0xb7, - 0x72, 0xef, 0x77, 0x2f, 0x5d, 0xe1, 0xcc, 0x2f, 0x3a, 0x23, 0x3e, 0xdb, 0x0b, 0xa7, 0x74, 0x34, - 0x71, 0xae, 0xf6, 0x22, 0x0b, 0x2f, 0x8a, 0xea, 0x0f, 0x7d, 0xf0, 0x2b, 0x00, 0x00, 0xff, 0xff, - 0xcd, 0xd7, 0xbe, 0xd5, 0xb1, 0x05, 0x00, 0x00, + // 707 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0x4d, 0x6f, 0xda, 0x4a, + 0x14, 0xc5, 0xc6, 0x7c, 0x5d, 0x02, 0xf1, 0xbb, 0x79, 0x8f, 0x07, 0x4f, 0xaf, 0x16, 0xf5, 0xa2, + 0x62, 0x45, 0x22, 0x92, 0x46, 0x5d, 0x36, 0xa5, 0xaa, 0x20, 0x4a, 0x22, 0x3a, 0x4a, 0x5b, 0xa9, + 0x9b, 0x6a, 0x62, 0xa6, 0xc1, 0x02, 0x3c, 0x8e, 0x3d, 0x54, 0xe1, 0x5f, 0xf4, 0xc7, 0xe4, 0x47, + 0x74, 0xd7, 0x2c, 0xbb, 0xac, 0x92, 0x65, 0x97, 0xfd, 0x03, 0xd5, 0x8c, 0xc1, 0x36, 0x84, 0x76, + 0x37, 0xe7, 0xde, 0x73, 0x66, 0xce, 0x9c, 0xb9, 0x36, 0x6c, 0x79, 0xec, 0x62, 0x36, 0xa1, 0x6d, + 0x3f, 0xe0, 0x82, 0x63, 0x3e, 0x42, 0xf6, 0x0f, 0x1d, 0xe0, 0x4c, 0x2d, 0x4f, 0x99, 0xa0, 0xd8, + 0x01, 0xe3, 0x7c, 0xee, 0xb3, 0xba, 0xd6, 0xd4, 0x5a, 0xd5, 0x8e, 0xd5, 0x5e, 0x68, 0x12, 0x46, + 0xfb, 0x94, 0x85, 0x21, 0xbd, 0x64, 0x92, 0x45, 0x14, 0x17, 0xf7, 0xa1, 0xf0, 0x92, 0x09, 0xea, + 0x4e, 0xc2, 0xba, 0xde, 0xd4, 0x5a, 0xe5, 0x4e, 0xe3, 0xa1, 0x6c, 0x41, 0x20, 0x4b, 0xa6, 0xfd, + 0x53, 0x83, 0x72, 0x6a, 0x2b, 0x2c, 0x82, 0x71, 0xc6, 0x3d, 0x66, 0x66, 0xb0, 0x02, 0xa5, 0x1e, + 0x0f, 0xc5, 0xeb, 0x19, 0x0b, 0xe6, 0xa6, 0x86, 0x08, 0xd5, 0x18, 0x12, 0xe6, 0x4f, 0xe6, 0xa6, + 0x8e, 0xff, 0x41, 0x4d, 0xd6, 0xde, 0xf8, 0x43, 0x2a, 0xd8, 0x19, 0x17, 0xee, 0x47, 0xd7, 0xa1, + 0xc2, 0xe5, 0x9e, 0x99, 0xc5, 0x06, 0xfc, 0x23, 0x7b, 0xa7, 0xfc, 0x13, 0x1b, 0xae, 0xb4, 0x8c, + 0x65, 0x6b, 0x30, 0xf3, 0x9c, 0xd1, 0x4a, 0x2b, 0x87, 0x55, 0x00, 0xd9, 0x7a, 0x37, 0xe2, 0x74, + 0xea, 0x9a, 0x79, 0xdc, 0x81, 0xed, 0x04, 0x47, 0xc7, 0x16, 0xa4, 0xb3, 0x01, 0x15, 0xa3, 0xee, + 0x88, 0x39, 0x63, 0xb3, 0x28, 0x9d, 0xc5, 0x30, 0xa2, 0x94, 0xf0, 0x11, 0x34, 0x36, 0x3b, 0x3b, + 0x72, 0xc6, 0x26, 0xd8, 0x5f, 0x35, 0xf8, 0xeb, 0x41, 0x28, 0xf8, 0x37, 0xe4, 0xde, 0xfa, 0x5e, + 0xdf, 0x57, 0xa9, 0x57, 0x48, 0x04, 0xf0, 0x00, 0xca, 0x7d, 0xff, 0xe0, 0xc8, 0x1b, 0x0e, 0x78, + 0x20, 0x64, 0xb4, 0xd9, 0x56, 0xb9, 0x83, 0xcb, 0x68, 0x93, 0x16, 0x49, 0xd3, 0x22, 
0xd5, 0x61, + 0xac, 0x32, 0xd6, 0x55, 0x87, 0x29, 0x55, 0x4c, 0x43, 0x0b, 0x80, 0xb0, 0x09, 0x9d, 0x47, 0x36, + 0x72, 0xcd, 0x6c, 0xab, 0x42, 0x52, 0x15, 0xac, 0x43, 0xc1, 0xe1, 0x33, 0x4f, 0xb0, 0xa0, 0x9e, + 0x55, 0x1e, 0x97, 0xd0, 0xde, 0x03, 0x48, 0x8e, 0xc7, 0x2a, 0xe8, 0xf1, 0x35, 0xf4, 0xbe, 0x8f, + 0x08, 0x86, 0xac, 0xab, 0xb9, 0xa8, 0x10, 0xb5, 0xb6, 0x9f, 0x4b, 0xc5, 0x61, 0x4a, 0xd1, 0x73, + 0x95, 0xc2, 0x20, 0x7a, 0xcf, 0x95, 0xf8, 0x84, 0x2b, 0xbe, 0x41, 0xf4, 0x13, 0x1e, 0xef, 0x90, + 0x4d, 0xed, 0x70, 0xbd, 0x1c, 0xd9, 0x81, 0xeb, 0x5d, 0xfe, 0x79, 0x64, 0x25, 0x63, 0xc3, 0xc8, + 0x22, 0x18, 0xe7, 0xee, 0x94, 0x2d, 0xce, 0x51, 0x6b, 0xdb, 0x7e, 0x30, 0x90, 0x52, 0x6c, 0x66, + 0xb0, 0x04, 0xb9, 0xe8, 0x79, 0x35, 0xfb, 0x03, 0x6c, 0x47, 0xfb, 0xf6, 0xa8, 0x37, 0x0c, 0x47, + 0x74, 0xcc, 0xf0, 0x59, 0x32, 0xfd, 0x9a, 0x9a, 0xfe, 0x35, 0x07, 0x31, 0x73, 0xfd, 0x13, 0x90, + 0x26, 0x7a, 0x53, 0xea, 0x28, 0x13, 0x5b, 0x44, 0xad, 0xed, 0x1b, 0x0d, 0x6a, 0x9b, 0x75, 0x92, + 0xde, 0x65, 0x81, 0x50, 0xa7, 0x6c, 0x11, 0xb5, 0xc6, 0x27, 0x50, 0xed, 0x7b, 0xae, 0x70, 0xa9, + 0xe0, 0x41, 0xdf, 0x1b, 0xb2, 0xeb, 0x45, 0xd2, 0x6b, 0x55, 0xc9, 0x23, 0x2c, 0xf4, 0xb9, 0x37, + 0x64, 0x0b, 0x5e, 0x94, 0xe7, 0x5a, 0x15, 0x6b, 0x90, 0xef, 0x72, 0x3e, 0x76, 0x59, 0xdd, 0x50, + 0xc9, 0x2c, 0x50, 0x9c, 0x57, 0x2e, 0xc9, 0xeb, 0xd8, 0x28, 0xe6, 0xcd, 0xc2, 0xb1, 0x51, 0x2c, + 0x98, 0x45, 0xfb, 0x46, 0x87, 0x4a, 0x64, 0xbb, 0xcb, 0x3d, 0x11, 0xf0, 0x09, 0x3e, 0x5d, 0x79, + 0x95, 0xc7, 0xab, 0x99, 0x2c, 0x48, 0x1b, 0x1e, 0x66, 0x0f, 0x76, 0x62, 0xeb, 0x6a, 0xfe, 0xd2, + 0xb7, 0xda, 0xd4, 0x92, 0x8a, 0xf8, 0x12, 0x29, 0x45, 0x74, 0xbf, 0x4d, 0x2d, 0xfc, 0x1f, 0x4a, + 0x0a, 0x9d, 0xf3, 0xbe, 0xaf, 0xee, 0x59, 0x21, 0x49, 0x01, 0x9b, 0x50, 0x56, 0xe0, 0x55, 0xc0, + 0xa7, 0xea, 0x5b, 0x90, 0xfd, 0x74, 0xc9, 0xee, 0xfd, 0xee, 0xcf, 0x55, 0x03, 0xec, 0x06, 0x8c, + 0x0a, 0xa6, 0xd8, 0x84, 0x5d, 0xcd, 0x58, 0x28, 0x4c, 0x0d, 0xff, 0x85, 0x9d, 0x95, 0xba, 0xb4, + 0x14, 0x32, 0x53, 0x7f, 0xb1, 0xff, 0xe5, 0xce, 0xd2, 0x6e, 0xef, 0x2c, 0xed, 0xfb, 0x9d, 0xa5, + 0x7d, 0xbe, 0xb7, 0x32, 0xb7, 0xf7, 0x56, 0xe6, 0xdb, 0xbd, 0x95, 0x79, 0xdf, 0xb8, 0x74, 0xc5, + 0x68, 0x76, 0xd1, 0x76, 0xf8, 0x74, 0x37, 0x9c, 0x50, 0x67, 0x3c, 0xba, 0xda, 0x8d, 0x22, 0xbc, + 0xc8, 0xab, 0x1f, 0xf8, 0xfe, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x17, 0x56, 0x28, 0x74, 0xd0, + 0x05, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { diff --git a/nebula.proto b/nebula.proto index 5e839be73..88e33b7e9 100644 --- a/nebula.proto +++ b/nebula.proto @@ -15,6 +15,7 @@ message NebulaMeta { HostWhoamiReply = 7; PathCheck = 8; PathCheckReply = 9; + HostUpdateNotificationAck = 10; } MessageType Type = 1; diff --git a/noiseutil/boring.go b/noiseutil/boring.go new file mode 100644 index 000000000..e9ad19bb6 --- /dev/null +++ b/noiseutil/boring.go @@ -0,0 +1,80 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + + // unsafe needed for go:linkname + _ "unsafe" + + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +// This is true for boringcrypto because the Seal function verifies that the +// nonce is strictly increasing. 
+const EncryptLockNeeded = true + +// NewGCMTLS is no longer exposed in go1.19+, so we need to link it in +// See: https://github.com/golang/go/issues/56326 +// +// NewGCMTLS is the internal method used with boringcrypto that provices a +// validated mode of AES-GCM which enforces the nonce is strictly +// monotonically increasing. This is the TLS 1.2 specification for nonce +// generation (which also matches the method used by the Noise Protocol) +// +// - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093 +// +//go:linkname newGCMTLS crypto/internal/boring.NewGCMTLS +func newGCMTLS(c cipher.Block) (cipher.AEAD, error) + +type cipherFn struct { + fn func([32]byte) noise.Cipher + name string +} + +func (c cipherFn) Cipher(k [32]byte) noise.Cipher { return c.fn(k) } +func (c cipherFn) CipherName() string { return c.name } + +// CipherAESGCM is the AES256-GCM AEAD cipher (using NewGCMTLS when GoBoring is present) +var CipherAESGCM noise.CipherFunc = cipherFn{cipherAESGCMBoring, "AESGCM"} + +func cipherAESGCMBoring(k [32]byte) noise.Cipher { + c, err := aes.NewCipher(k[:]) + if err != nil { + panic(err) + } + gcm, err := newGCMTLS(c) + if err != nil { + panic(err) + } + return aeadCipher{ + gcm, + func(n uint64) []byte { + var nonce [12]byte + binary.BigEndian.PutUint64(nonce[4:], n) + return nonce[:] + }, + } +} + +type aeadCipher struct { + cipher.AEAD + nonce func(uint64) []byte +} + +func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte { + return c.Seal(out, c.nonce(n), plaintext, ad) +} + +func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) { + return c.Open(out, c.nonce(n), ciphertext, ad) +} diff --git a/noiseutil/boring_test.go b/noiseutil/boring_test.go new file mode 100644 index 000000000..8c8843924 --- /dev/null +++ b/noiseutil/boring_test.go @@ -0,0 +1,46 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "crypto/boring" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptLockNeeded(t *testing.T) { + assert.True(t, EncryptLockNeeded) +} + +// Ensure NewGCMTLS validates the nonce is non-repeating +func TestNewGCMTLS(t *testing.T) { + assert.True(t, boring.Enabled()) + + // Test Case 16 from GCM Spec: + // - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf + // - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418 + key, _ := hex.DecodeString("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308") + iv, _ := hex.DecodeString("cafebabefacedbaddecaf888") + plaintext, _ := hex.DecodeString("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39") + aad, _ := hex.DecodeString("feedfacedeadbeeffeedfacedeadbeefabaddad2") + expected, _ := hex.DecodeString("522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662") + 
expectedTag, _ := hex.DecodeString("76fc6ece0f4e1768cddf8853bb2d551b") + + expected = append(expected, expectedTag...) + + var keyArray [32]byte + copy(keyArray[:], key) + c := CipherAESGCM.Cipher(keyArray) + aead := c.(aeadCipher).AEAD + + dst := aead.Seal([]byte{}, iv, plaintext, aad) + assert.Equal(t, expected, dst) + + // We expect this to fail since we are re-encrypting with a repeat IV + assert.PanicsWithError(t, "boringcrypto: EVP_AEAD_CTX_seal failed", func() { + dst = aead.Seal([]byte{}, iv, plaintext, aad) + }) +} diff --git a/noiseutil/nist.go b/noiseutil/nist.go new file mode 100644 index 000000000..90e77abc7 --- /dev/null +++ b/noiseutil/nist.go @@ -0,0 +1,68 @@ +package noiseutil + +import ( + "crypto/ecdh" + "crypto/rand" + "fmt" + "io" + + "github.com/flynn/noise" +) + +// DHP256 is the NIST P-256 ECDH function +var DHP256 noise.DHFunc = newNISTCurve("P256", ecdh.P256(), 32) + +type nistCurve struct { + name string + curve ecdh.Curve + dhLen int + pubLen int +} + +func newNISTCurve(name string, curve ecdh.Curve, byteLen int) nistCurve { + return nistCurve{ + name: name, + curve: curve, + dhLen: byteLen, + // Standard uncompressed format, type (1 byte) plus both coordinates + pubLen: 1 + 2*byteLen, + } +} + +func (c nistCurve) GenerateKeypair(rng io.Reader) (noise.DHKey, error) { + if rng == nil { + rng = rand.Reader + } + privkey, err := c.curve.GenerateKey(rng) + if err != nil { + return noise.DHKey{}, err + } + pubkey := privkey.PublicKey() + return noise.DHKey{Private: privkey.Bytes(), Public: pubkey.Bytes()}, nil +} + +func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) { + ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + } + ecdhPrivKey, err := c.curve.NewPrivateKey(privkey) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + } + + return ecdhPrivKey.ECDH(ecdhPubKey) +} + +func (c nistCurve) DHLen() int { + // NOTE: Noise Protocol specifies "DHLen" to represent two things: + // - The size of the public key + // - The return size of the DH() function + // But for standard NIST ECDH, the sizes of these are different. + // Luckily, the flynn/noise library actually only uses this DHLen() + // value to represent the public key size, so that is what we are + // returning here. The length of the DH() return bytes are unaffected by + // this value here. 
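// As a concrete illustration of the size mismatch described above (a sketch
// only, using the crypto/ecdh calls from this file; peerPub stands in for a
// peer's *ecdh.PublicKey and is not defined here):
//
//	priv, _ := ecdh.P256().GenerateKey(rand.Reader)
//	len(priv.PublicKey().Bytes()) // 65: uncompressed point 0x04 || X || Y, i.e. 1 + 2*32 (pubLen)
//	secret, _ := priv.ECDH(peerPub)
//	len(secret)                   // 32: the shared-secret x coordinate (dhLen)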
+ return c.pubLen +} +func (c nistCurve) DHName() string { return c.name } diff --git a/noiseutil/notboring.go b/noiseutil/notboring.go new file mode 100644 index 000000000..be746f409 --- /dev/null +++ b/noiseutil/notboring.go @@ -0,0 +1,14 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +const EncryptLockNeeded = false + +// CipherAESGCM is the standard noise.CipherAESGCM when boringcrypto is not enabled +var CipherAESGCM noise.CipherFunc = noise.CipherAESGCM diff --git a/noiseutil/notboring_test.go b/noiseutil/notboring_test.go new file mode 100644 index 000000000..b865391e7 --- /dev/null +++ b/noiseutil/notboring_test.go @@ -0,0 +1,14 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptLockNeeded(t *testing.T) { + assert.False(t, EncryptLockNeeded) +} diff --git a/notboring.go b/notboring.go new file mode 100644 index 000000000..c86b0bc3e --- /dev/null +++ b/notboring.go @@ -0,0 +1,6 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package nebula + +var boringEnabled = func() bool { return false } diff --git a/outside.go b/outside.go index c43a385d3..818e2ae4b 100644 --- a/outside.go +++ b/outside.go @@ -21,7 +21,23 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func readOutsidePackets(f *Interface) udp.EncReader { + return func( + addr *udp.Addr, + out []byte, + packet []byte, + header *header.H, + fwPacket *firewall.Packet, + lhh udp.LightHouseHandlerFunc, + nb []byte, + q int, + localCache firewall.ConntrackCache, + ) { + f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) + } +} + +func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -48,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by var hostinfo *HostInfo // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { - hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { - hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } var ci *ConnectionState @@ -67,7 +83,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by switch h.Subtype { case header.MessageNone: - f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) + if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { + return + } case header.MessageRelay: // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. 
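The hunk above spells out the relayed-packet layout: a 16-byte parsed Nebula header carried as associated data, the AD-protected payload, and a trailing 16-byte AEAD signature value. A minimal sketch of that framing (splitRelayBody is a hypothetical helper, not part of this change; the 16-byte sizes are taken from the comment, assuming header.Len matches them):

	// splitRelayBody slices a received relay frame into its three sections.
	func splitRelayBody(packet []byte) (hdr, body, tag []byte, ok bool) {
		const headerLen = 16 // parsed Nebula header, sent as AD
		const tagLen = 16    // trailing AEAD signature value
		if len(packet) < headerLen+tagLen {
			return nil, nil, nil, false
		}
		hdr = packet[:headerLen]
		body = packet[headerLen : len(packet)-tagLen]
		tag = packet[len(packet)-tagLen:]
		return hdr, body, tag, true
	}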
@@ -84,17 +102,15 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. f.handleHostRoaming(hostinfo, addr) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.RelayUsed(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This shouldn't happen! - hostinfo.logger(f.l).WithField("hostinfo", hostinfo.vpnIp).WithField("remoteIndex", h.RemoteIndex).Errorf("HostInfo missing remote index") - // Delete my local index from the hostmap - f.hostMap.DeleteRelayIdx(h.RemoteIndex) - // When the peer doesn't receive any return traffic, its connection_manager will eventually clean up - // the broken relay when it cleans up the associated HostInfo object. + // its internal mapping. This should never happen. + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -106,15 +122,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by return case ForwardingType: // Find the target HostInfo relay object - targetHI, err := f.hostMap.QueryVpnIp(relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) if err != nil { - hostinfo.logger(f.l).WithField("peerIp", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") - return - } - // find the target Relay info object - targetRelay, ok := targetHI.relayState.QueryRelayForByIp(hostinfo.vpnIp) - if !ok { - hostinfo.logger(f.l).WithField("peerIp", relay.PeerIp).Info("Failed to find relay in hostinfo") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") return } @@ -130,7 +140,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithField("targetRelayState", targetRelay.State).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -153,7 +163,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by return } - lhf(addr, hostinfo.vpnIp, d, f) + lhf(addr, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic @@ -188,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - HandleIncomingHandshake(f, addr, via, packet, h, hostinfo) + f.handshakeManager.HandleIncoming(addr, via, packet, h) return case header.RecvError: @@ -242,12 +252,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo) { - //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately - f.connectionManager.ClearLocalIndex(hostInfo.localIndexId) - 
f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId) - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) - - f.hostMap.DeleteHostInfo(hostInfo) + final := f.hostMap.DeleteHostInfo(hostInfo) + if final { + // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + } } // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote @@ -371,7 +380,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) @@ -379,30 +388,33 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") //TODO: maybe after build 64 is out? 06/14/2018 - NB //f.sendRecvError(hostinfo.remote, header.RemoteIndex) - return + return false } err = newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("packet", out). Warnf("Error while validating inbound packet") - return + return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). Debugln("dropping out of window packet") - return + return false } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { + // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore + // This gives us a buffer to build the reject packet in + f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). Debugln("dropping inbound packet") } - return + return false } f.connectionManager.In(hostinfo.localIndexId) @@ -410,6 +422,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out if err != nil { f.l.WithError(err).Error("Failed to write to tun") } + return true } func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { @@ -438,29 +451,23 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { Debug("Recv error received") } - // First, clean up in the pending hostmap - f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex) - - hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) - if err != nil { - f.l.Debugln(err, ": ", h.RemoteIndex) + hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) + if hostinfo == nil { + f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") return } - hostinfo.Lock() - defer hostinfo.Unlock() - if !hostinfo.RecvErrorExceeded() { return } + if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { f.l.Infoln("Someone spoofing recv_errors? 
", addr, hostinfo.remote) return } f.closeTunnel(hostinfo) - // We also delete it from pending hostmap to allow for - // fast reconnect. + // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } diff --git a/overlay/route.go b/overlay/route.go index e8626bb04..64c624c7e 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -1,6 +1,7 @@ package overlay import ( + "bytes" "fmt" "math" "net" @@ -14,14 +15,44 @@ import ( ) type Route struct { - MTU int - Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + MTU int + Metric int + Cidr *net.IPNet + Via *iputil.VpnIp + Install bool } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) { - routeTree := cidr.NewTree4() +// Equal determines if a route that could be installed in the system route table is equal to another +// Via is ignored since that is only consumed within nebula itself +func (r Route) Equal(t Route) bool { + if !r.Cidr.IP.Equal(t.Cidr.IP) { + return false + } + if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) { + return false + } + if r.Metric != t.Metric { + return false + } + if r.MTU != t.MTU { + return false + } + if r.Install != t.Install { + return false + } + return true +} + +func (r Route) String() string { + s := r.Cidr.String() + if r.Metric != 0 { + s += fmt.Sprintf(" metric: %v", r.Metric) + } + return s +} + +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { + routeTree := cidr.NewTree4[iputil.VpnIp]() for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) @@ -81,7 +112,8 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - MTU: mtu, + Install: true, + MTU: mtu, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) @@ -182,10 +214,20 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { viaVpnIp := iputil.Ip2VpnIp(nVia) + install := true + rInstall, ok := m["install"] + if ok { + install, err = strconv.ParseBool(fmt.Sprintf("%v", rInstall)) + if err != nil { + return nil, fmt.Errorf("entry %v.install in tun.unsafe_routes is not a boolean: %v", i+1, err) + } + } + r := Route{ - Via: &viaVpnIp, - MTU: mtu, - Metric: metric, + Via: &viaVpnIp, + MTU: mtu, + Metric: metric, + Install: install, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) diff --git a/overlay/route_test.go b/overlay/route_test.go index 1d4286d0e..46fb87ceb 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -92,6 +92,8 @@ func Test_parseRoutes(t *testing.T) { tested := 0 for _, r := range routes { + assert.True(t, r.Install) + if r.MTU == 8000 { assert.Equal(t, "10.0.0.1/32", r.Cidr.String()) tested++ @@ -205,35 +207,45 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + // bad install + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + routes, err = parseUnsafeRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": 
"1.0.0.0/29"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, err) - assert.Len(t, routes, 3) + assert.Len(t, routes, 4) tested := 0 for _, r := range routes { if r.MTU == 8000 { assert.Equal(t, "1.0.0.1/32", r.Cidr.String()) + assert.False(t, r.Install) tested++ } else if r.MTU == 9000 { assert.Equal(t, 9000, r.MTU) assert.Equal(t, "1.0.0.0/29", r.Cidr.String()) + assert.True(t, r.Install) tested++ } else { assert.Equal(t, 1500, r.MTU) assert.Equal(t, 1234, r.Metric) assert.Equal(t, "1.0.0.2/32", r.Cidr.String()) + assert.True(t, r.Install) tested++ } } - if tested != 3 { - t.Fatal("Did not see both unsafe_routes") + if tested != 4 { + t.Fatal("Did not see all unsafe_routes") } } @@ -253,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) { assert.NoError(t, err) ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - r := routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) + ok, r := routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) + ok, r = routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.Nil(t, r) + ok, r = routeTree.MostSpecificContains(ip) + assert.False(t, ok) } diff --git a/overlay/tun.go b/overlay/tun.go index 3da50b8e5..cedd7fe76 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -10,42 +10,63 @@ import ( const DefaultMTU = 1300 -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) { - routes, err := parseRoutes(c, tunCidr) +// TODO: We may be able to remove routines +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) + +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + switch { + case c.GetBool("tun.disabled", false): + tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + return tun, nil + + default: + return newTun(c, l, tunCidr, routines > 1) + } +} + +func NewFdDeviceFromConfig(fd *int) DeviceFactory { + return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, tunCidr) + } +} + +func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { + if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { + return false, nil, nil + } + + routes, err := parseRoutes(c, cidr) if err != nil { - return nil, util.NewContextualError("Could not parse 
tun.routes", nil, err) + return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) + unsafeRoutes, err := parseUnsafeRoutes(c, cidr) if err != nil { - return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } - routes = append(routes, unsafeRoutes...) - switch { - case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) - return tun, nil + routes = append(routes, unsafeRoutes...) + return true, routes, nil +} - case fd != nil: - return newTunFromFd( - l, - *fd, - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - ) +// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table. +// Via is not used to evaluate since it does not affect the system route table. +func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route { + var removed []Route + has := func(entry Route) bool { + for _, check := range newRoutes { + if check.Equal(entry) { + return true + } + } + return false + } - default: - return newTun( - l, - c.GetString("tun.dev", ""), - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - routines > 1, - ) + for _, oldEntry := range oldRoutes { + if !has(oldEntry) { + removed = append(removed, oldEntry) + } } + + return removed } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 321aec848..c15827fe6 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -8,56 +8,85 @@ import ( "io" "net" "os" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int cidr *net.IPNet - routeTree *cidr.Tree4 + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] l *logrus.Logger } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. 
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: deviceFd, cidr: cidr, l: l, - routeTree: routeTree, - }, nil + } + + err := t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t tun) Activate() error { return nil } +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes + t.Routes.Store(&routes) + t.routeTree.Store(routeTree) + return nil +} + func (t *tun) Cidr() *net.IPNet { return t.cidr } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index d7b488439..1c6382827 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -9,12 +9,15 @@ import ( "io" "net" "os" + "sync/atomic" "syscall" "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) @@ -24,8 +27,9 @@ type tun struct { Device string cidr *net.IPNet DefaultMTU int - Routes []Route - routeTree *cidr.Tree4 + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + linkAddr *netroute.LinkAddr l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -47,14 +51,6 @@ type ifReq struct { pad [8]byte } -func ioctl(a1, a2, a3 uintptr) error { - _, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3) - if errno != 0 { - return errno - } - return nil -} - var sockaddrCtlSize uintptr = 32 const ( @@ -77,12 +73,8 @@ type ifreqMTU struct { pad [8]byte } -func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { + name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { _, err := fmt.Sscanf(name, "utun%d", &ifIndex) @@ -150,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout file := os.NewFile(uintptr(fd), "") - tun := &tun{ + t := &tun{ ReadWriteCloser: file, Device: name, cidr: cidr, - DefaultMTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } - return tun, nil + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, 
nil } func (t *tun) deviceBytes() (o [16]byte) { @@ -170,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -194,10 +196,10 @@ func (t *tun) Activate() error { unix.SOCK_DGRAM, unix.IPPROTO_IP, ) - if err != nil { return err } + defer unix.Close(s) fd := uintptr(s) @@ -268,6 +270,7 @@ func (t *tun) Activate() error { if linkAddr == nil { return fmt.Errorf("unable to discover link_addr for tun interface") } + t.linkAddr = linkAddr copy(routeAddr.IP[:], addr[:]) copy(maskAddr.IP[:], mask[:]) @@ -286,35 +289,50 @@ func (t *tun) Activate() error { } // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil { - // We don't allow route MTUs so only install routes with a via - continue - } + return t.addRoutes(false) +} - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { - if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") - } else { - return err - } + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } - // TODO how to set metric + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } } return nil } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) + ok, r := t.routeTree.Load().MostSpecificContains(ip) + if ok { + return r } return 0 @@ -348,6 +366,88 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { return nil, nil } +func (t *tun) addRoutes(logErrors bool) error { + routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + + defer func() { + unix.Shutdown(routeSock, unix.SHUT_RDWR) + err := unix.Close(routeSock) + if err != nil { + t.l.WithError(err).Error("failed to close AF_ROUTE socket") + } + }() + + routeAddr := &netroute.Inet4Addr{} + maskAddr := &netroute.Inet4Addr{} + routes := *t.Routes.Load() + for _, r := range routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + copy(routeAddr.IP[:], r.Cidr.IP.To4()) + copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + + err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + if err != nil { + if errors.Is(err, unix.EEXIST) { + t.l.WithField("route", r.Cidr). 
+ Warnf("unable to add unsafe_route, identical route already exists") + } else { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } else { + t.l.WithField("route", r).Info("Added route") + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + + defer func() { + unix.Shutdown(routeSock, unix.SHUT_RDWR) + err := unix.Close(routeSock) + if err != nil { + t.l.WithError(err).Error("failed to close AF_ROUTE socket") + } + }() + + routeAddr := &netroute.Inet4Addr{} + maskAddr := &netroute.Inet4Addr{} + + for _, r := range routes { + if !r.Install { + continue + } + + copy(routeAddr.IP[:], r.Cidr.IP.To4()) + copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + + err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { r := netroute.RouteMessage{ Version: unix.RTM_VERSION, @@ -373,6 +473,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } +func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { + r := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + Addrs: []netroute.Addr{ + unix.RTAX_DST: addr, + unix.RTAX_GATEWAY: link, + unix.RTAX_NETMASK: mask, + }, + } + + data, err := r.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + func (t *tun) Read(to []byte) (int, error) { buf := make([]byte, len(to)+4) diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index b7f7273c2..e1e4ede67 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -1,7 +1,6 @@ package overlay import ( - "encoding/binary" "fmt" "io" "net" @@ -75,38 +74,15 @@ func (t *disabledTun) Read(b []byte) (int, error) { } func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { - // Return early if this is not a simple ICMP Echo Request - //TODO: make constants out of these - if !(len(b) >= 28 && len(b) <= 9001 && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) { + out := make([]byte, len(b)) + out = iputil.CreateICMPEchoResponse(b, out) + if out == nil { return false } - // We don't support fragmented packets - if b[7] != 0 || (b[6]&0x2F != 0) { - return false - } - - buf := make([]byte, len(b)) - copy(buf, b) - - // Swap dest / src IPs and recalculate checksum - ipv4 := buf[0:20] - copy(ipv4[12:16], b[16:20]) - copy(ipv4[16:20], b[12:16]) - ipv4[10] = 0 - ipv4[11] = 0 - binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4)) - - // Change type to ICMP Echo Reply and recalculate checksum - icmp := buf[20:] - icmp[0] = 0 - icmp[2] = 0 - icmp[3] = 0 - binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp)) - // attempt to write it, but don't block select { - case t.read <- buf: + case t.read <- out: default: t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") } @@ -154,22 +130,3 @@ func (p prettyPacket) String() 
string { return s.String() } - -func ipChecksum(b []byte) uint16 { - var c uint32 - sz := len(b) - 1 - - for i := 0; i < sz; i += 2 { - c += uint32(b[i]) << 8 - c += uint32(b[i+1]) - } - if sz%2 == 0 { - c += uint32(b[sz]) << 8 - } - - for (c >> 16) > 0 { - c = (c & 0xffff) + (c >> 16) - } - - return ^uint16(c) -} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 0a3f722c0..3b1b80f1a 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -4,28 +4,54 @@ package overlay import ( + "bytes" + "errors" "fmt" "io" + "io/fs" "net" "os" "os/exec" - "regexp" "strconv" - "strings" + "sync/atomic" + "syscall" + "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" ) -var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) +const ( + // FIODGNAME is defined in sys/sys/filio.h on FreeBSD + // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) + FIODGNAME = 0x80106678 +) + +type fiodgnameArg struct { + length int32 + pad [4]byte + buf unsafe.Pointer +} + +type ifreqRename struct { + Name [16]byte + Data uintptr +} + +type ifreqDestroy struct { + Name [16]byte + pad [16]byte +} type tun struct { Device string cidr *net.IPNet MTU int - Routes []Route - routeTree *cidr.Tree4 + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] l *logrus.Logger io.ReadWriteCloser @@ -33,67 +59,174 @@ type tun struct { func (t *tun) Close() error { if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() + if err := t.ReadWriteCloser.Close(); err != nil { + return err + } + + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifreq := ifreqDestroy{Name: t.deviceBytes()} + + // Destroy the interface + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + return err } + return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { + // Try to open existing tun device + var file *os.File + var err error + deviceName := c.GetString("tun.dev", "") + if deviceName != "" { + file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + } + if errors.Is(err, fs.ErrNotExist) || deviceName == "" { + // If the device doesn't already exist, request a new one and rename it + file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) + } if err != nil { return nil, err } - if strings.HasPrefix(deviceName, "/dev/") { - deviceName = strings.TrimPrefix(deviceName, "/dev/") + rawConn, err := file.SyscallConn() + if err != nil { + return nil, fmt.Errorf("SyscallConn: %v", err) } - if !deviceNameRE.MatchString(deviceName) { - return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`") + + var name [16]byte + var ctrlErr error + rawConn.Control(func(fd uintptr) { + // Read the name of the interface + arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} + ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg))) + }) + if ctrlErr != nil { + return nil, 
err } - return &tun{ - Device: deviceName, - cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, - l: l, - }, nil -} -func (t *tun) Activate() error { - var err error - t.ReadWriteCloser, err = os.OpenFile("/dev/"+t.Device, os.O_RDWR, 0) + ifName := string(bytes.TrimRight(name[:], "\x00")) + if deviceName == "" { + deviceName = ifName + } + + // If the name doesn't match the desired interface name, rename it now + if ifName != deviceName { + s, err := syscall.Socket( + syscall.AF_INET, + syscall.SOCK_DGRAM, + syscall.IPPROTO_IP, + ) + if err != nil { + return nil, err + } + defer syscall.Close(s) + + fd := uintptr(s) + + var fromName [16]byte + var toName [16]byte + copy(fromName[:], ifName) + copy(toName[:], deviceName) + + ifrr := ifreqRename{ + Name: fromName, + Data: uintptr(unsafe.Pointer(&toName)), + } + + // Set the device name + ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + } + + t := &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) if err != nil { - return fmt.Errorf("activate failed: %v", err) + return nil, err } + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *tun) Activate() error { + var err error // TODO use syscalls instead of exec.Command - t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) - if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil { + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) - if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil { + + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) } - t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) - if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil { + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } + // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil { - // We don't allow route MTUs so only install routes with a via - continue + return t.addRoutes(false) +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + 
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } - t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) - if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil { - return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } @@ -101,12 +234,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { @@ -120,3 +249,50 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } + +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *tun) deviceBytes() (o [16]byte) { + for i, c := range t.Device { + o[i] = byte(c) + } + return +} diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 59c190e0b..ba15d665e 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -10,48 +10,79 @@ import ( "net" "os" "sync" + "sync/atomic" "syscall" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser cidr *net.IPNet - routeTree *cidr.Tree4 + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + l *logrus.Logger } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { + file := os.NewFile(uintptr(deviceFd), "/dev/tun") + t := &tun{ + cidr: cidr, + ReadWriteCloser: &tunReadCloser{f: file}, + l: l, + } + + err := t.reload(c, true) if err != nil { return nil, err } - file := 
os.NewFile(uintptr(deviceFd), "/dev/tun") - return &tun{ - cidr: cidr, - ReadWriteCloser: &tunReadCloser{f: file}, - routeTree: routeTree, - }, nil + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *tun) Activate() error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err } - return 0 + // Teach nebula how to handle the routes + t.Routes.Store(&routes) + t.routeTree.Store(routeTree) + return nil +} + +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } // The following is hoisted up from water, we do this so we can inject our own fd on iOS diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 1406438b0..1f6580edb 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,31 +4,41 @@ package overlay import ( + "bytes" "fmt" "io" "net" "os" "strings" + "sync/atomic" "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) type tun struct { io.ReadWriteCloser - fd int - Device string - cidr *net.IPNet - MaxMTU int - DefaultMTU int - TXQueueLen int - Routes []Route - routeTree *cidr.Tree4 - l *logrus.Logger + fd int + Device string + cidr *net.IPNet + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr + + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeChan chan struct{} + useSystemRoutes bool + + l *logrus.Logger } type ifReq struct { @@ -37,14 +47,6 @@ type ifReq struct { pad [8]byte } -func ioctl(a1, a2, a3 uintptr) error { - _, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3) - if errno != 0 { - return errno - } - return nil -} - type ifreqAddr struct { Name [16]byte Addr unix.RawSockaddrInet4 @@ -63,28 +65,20 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, true) +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { + file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + + t, err := newTunGeneric(c, l, file, cidr) if err != nil { return nil, err } - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + t.Device = "tun0" - return &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), - Device: "tun0", - cidr: cidr, - DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - Routes: routes, - routeTree: routeTree, - l: l, - }, nil + return t, nil } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -95,42 +89,111 @@ func newTun(l 
*logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } - copy(req.Name[:], deviceName) + copy(req.Name[:], c.GetString("tun.dev", "")) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") + t, err := newTunGeneric(c, l, file, cidr) + if err != nil { + return nil, err + } - maxMTU := defaultMTU - for _, r := range routes { + t.Device = name + + return t, nil +} + +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { + t := &tun{ + ReadWriteCloser: file, + fd: int(file.Fd()), + cidr: cidr, + TXQueueLen: c.GetInt("tun.tx_queue", 500), + useSystemRoutes: c.GetBool("tun.use_system_route_table", false), + l: l, + } + + err := t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *tun) reload(c *config.C, initial bool) error { + routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !routeChange && !c.HasChanged("tun.mtu") { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, true) + if err != nil { + return err + } + + oldDefaultMTU := t.DefaultMTU + oldMaxMTU := t.MaxMTU + newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU) + newMaxMTU := newDefaultMTU + for i, r := range routes { if r.MTU == 0 { - r.MTU = defaultMTU + routes[i].MTU = newDefaultMTU } - if r.MTU > maxMTU { - maxMTU = r.MTU + if r.MTU > t.MaxMTU { + newMaxMTU = r.MTU } } - routeTree, err := makeRouteTree(l, routes, true) - if err != nil { - return nil, err + t.MaxMTU = newMaxMTU + t.DefaultMTU = newDefaultMTU + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + if oldMaxMTU != newMaxMTU { + t.setMTU() + t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + } + + if oldDefaultMTU != newDefaultMTU { + err := t.setDefaultRoute() + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } + } + + // Remove first, if the system removes a wanted route hopefully it will be re-added next + t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // This should never be called since addRoutes should log its own errors in a reload condition + util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l) + } } - return &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), - Device: name, - cidr: cidr, - MaxMTU: maxMTU, - DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - Routes: routes, - routeTree: routeTree, - l: l, - }, nil + return nil } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { @@ -152,12 +215,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *tun) Write(b []byte) (int, error) { @@ -183,16 +242,20 @@ func (t *tun) 
Write(b []byte) (int, error) { } } -func (t tun) deviceBytes() (o [16]byte) { +func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } -func (t tun) Activate() error { +func (t *tun) Activate() error { devName := t.deviceBytes() + if t.useSystemRoutes { + t.watchRoutes() + } + var addr, mask [4]byte copy(addr[:], t.cidr.IP.To4()) @@ -206,7 +269,7 @@ func (t tun) Activate() error { if err != nil { return err } - fd := uintptr(s) + t.ioctlFd = uintptr(s) ifra := ifreqAddr{ Name: devName, @@ -217,52 +280,76 @@ func (t tun) Activate() error { } // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { return fmt.Errorf("failed to set tun address: %s", err) } // Set the device network ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { return fmt.Errorf("failed to set tun netmask: %s", err) } // Set the device name ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } - // Set the MTU on the device - ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)} - if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { - // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") - } + // Setup our default MTU + t.setMTU() // Set the transmit queue length ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss t.l.WithError(err).Error("Failed to set tun tx queue length") } // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - // Set the routes link, err := netlink.LinkByName(t.Device) if err != nil { return fmt.Errorf("failed to get tun device link: %s", err) } + t.deviceIndex = link.Attrs().Index + + if err = t.setDefaultRoute(); err != nil { + return err + } + + // Set the routes + if err = t.addRoutes(false); err != nil { + return err + } + + // Run the interface + ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to run tun device: %s", err) + } + return nil +} + +func (t *tun) setMTU() { + // Set the MTU on the device + ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} + if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { + // This is currently a non fatal condition because the route table must have the MTU set appropriately as well + t.l.WithError(err).Error("Failed to set tun mtu") + } +} + +func (t *tun) setDefaultRoute() error 
{ // Default route dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} nr := netlink.Route{ - LinkIndex: link.Attrs().Index, + LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), @@ -272,15 +359,24 @@ func (t tun) Activate() error { Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } - err = netlink.RouteReplace(&nr) + err := netlink.RouteReplace(&nr) if err != nil { return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) } + return nil +} + +func (t *tun) addRoutes(logErrors bool) error { // Path routes - for _, r := range t.Routes { + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Install { + continue + } + nr := netlink.Route{ - LinkIndex: link.Attrs().Index, + LinkIndex: t.deviceIndex, Dst: r.Cidr, MTU: r.MTU, AdvMSS: t.advMSS(r), @@ -291,21 +387,49 @@ func (t tun) Activate() error { nr.Priority = r.Metric } - err = netlink.RouteAdd(&nr) + err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err) + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") } } - // Run the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to run tun device: %s", err) - } - return nil } +func (t *tun) removeRoutes(routes []Route) { + for _, r := range routes { + if !r.Install { + continue + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: r.Cidr, + MTU: r.MTU, + AdvMSS: t.advMSS(r), + Scope: unix.RT_SCOPE_LINK, + } + + if r.Metric > 0 { + nr.Priority = r.Metric + } + + err := netlink.RouteDel(&nr) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } +} + func (t *tun) Cidr() *net.IPNet { return t.cidr } @@ -314,7 +438,7 @@ func (t *tun) Name() string { return t.Device } -func (t tun) advMSS(r Route) int { +func (t *tun) advMSS(r Route) int { mtu := r.MTU if r.MTU == 0 { mtu = t.DefaultMTU @@ -326,3 +450,87 @@ func (t tun) advMSS(r Route) int { } return 0 } + +func (t *tun) watchRoutes() { + rch := make(chan netlink.RouteUpdate) + doneChan := make(chan struct{}) + + if err := netlink.RouteSubscribe(rch, doneChan); err != nil { + t.l.WithError(err).Errorf("failed to subscribe to system route changes") + return + } + + t.routeChan = doneChan + + go func() { + for { + select { + case r := <-rch: + t.updateRoutes(r) + case <-doneChan: + // netlink.RouteSubscriber will close the rch for us + return + } + } + }() +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + if r.Gw == nil { + // Not a gateway route, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") + return + } + + if !t.cidr.Contains(r.Gw) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + return + } + + if x := r.Dst.IP.To4(); x == nil { + // Nebula only handles ipv4 on the overlay currently + t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + return + } + + newTree := cidr.NewTree4[iputil.VpnIp]() + if r.Type == unix.RTM_NEWROUTE { + for _, oldR := range t.routeTree.Load().List() { + newTree.AddCIDR(oldR.CIDR, 
oldR.Value) + } + + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") + newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + + } else { + gw := iputil.Ip2VpnIp(r.Gw) + for _, oldR := range t.routeTree.Load().List() { + if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { + // This is the record to delete + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") + continue + } + + newTree.AddCIDR(oldR.CIDR, oldR.Value) + } + } + + t.routeTree.Store(newTree) +} + +func (t *tun) Close() error { + if t.routeChan != nil { + close(t.routeChan) + } + + if t.ReadWriteCloser != nil { + t.ReadWriteCloser.Close() + } + + if t.ioctlFd > 0 { + os.NewFile(t.ioctlFd, "ioctlFd").Close() + } + + return nil +} diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go index 6c2043d97..1c1842da5 100644 --- a/overlay/tun_linux_test.go +++ b/overlay/tun_linux_test.go @@ -7,19 +7,19 @@ import "testing" var runAdvMSSTests = []struct { name string - tun tun + tun *tun r Route expected int }{ // Standard case, default MTU is the device max MTU - {"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, - {"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, - {"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, + {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, + {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, + {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, // Case where we have a route MTU set higher than the default - {"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, - {"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, - {"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, + {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, + {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, + {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, } func TestTunAdvMSS(t *testing.T) { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go new file mode 100644 index 000000000..cc0216fe9 --- /dev/null +++ b/overlay/tun_netbsd.go @@ -0,0 +1,233 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "strconv" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" +) + +type ifreqDestroy struct { + Name [16]byte + pad [16]byte +} + +type tun struct { + Device string + cidr *net.IPNet + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + l *logrus.Logger + + io.ReadWriteCloser +} + +func (t *tun) Close() error { + if t.ReadWriteCloser != nil { + if err := t.ReadWriteCloser.Close(); err != nil { + return err + } + + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifreq := ifreqDestroy{Name: t.deviceBytes()} + + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + + return err + } + return nil +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { + return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") +} + +var deviceNameRE = 
regexp.MustCompile(`^tun[0-9]+$`) + +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { + // Try to open tun device + var file *os.File + var err error + deviceName := c.GetString("tun.dev", "") + if deviceName == "" { + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") + } + if !deviceNameRE.MatchString(deviceName) { + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") + } + + file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + t := &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *tun) Activate() error { + var err error + + // TODO use syscalls instead of exec.Command + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add': %s", err) + } + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + // Unsafe path routes + return t.addRoutes(false) +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r +} + +func (t *tun) Cidr() *net.IPNet { + return t.cidr +} + +func (t *tun) Name() string { + return t.Device +} + +func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") +} + +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed 
to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *tun) deviceBytes() (o [16]byte) { + for i, c := range t.Device { + o[i] = byte(c) + } + return +} diff --git a/overlay/tun_notwin.go b/overlay/tun_notwin.go new file mode 100644 index 000000000..2fab9274b --- /dev/null +++ b/overlay/tun_notwin.go @@ -0,0 +1,14 @@ +//go:build !windows +// +build !windows + +package overlay + +import "syscall" + +func ioctl(a1, a2, a3 uintptr) error { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3) + if errno != 0 { + return errno + } + return nil +} diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go new file mode 100644 index 000000000..53f57b137 --- /dev/null +++ b/overlay/tun_openbsd.go @@ -0,0 +1,245 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "strconv" + "sync/atomic" + "syscall" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" +) + +type tun struct { + Device string + cidr *net.IPNet + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + l *logrus.Logger + + io.ReadWriteCloser + + // cache out buffer since we need to prepend 4 bytes for tun metadata + out []byte +} + +func (t *tun) Close() error { + if t.ReadWriteCloser != nil { + return t.ReadWriteCloser.Close() + } + + return nil +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { + return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") +} + +var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) + +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { + deviceName := c.GetString("tun.dev", "") + if deviceName == "" { + return nil, fmt.Errorf("a device name in the format of tunN must be specified") + } + + if !deviceNameRE.MatchString(deviceName) { + return nil, fmt.Errorf("a device name in the format of tunN must be specified") + } + + file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + t := &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := 
t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *tun) Activate() error { + var err error + // TODO use syscalls instead of exec.Command + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add': %s", err) + } + + // Unsafe path routes + return t.addRoutes(false) +} + +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r +} + +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *tun) Cidr() *net.IPNet { + return t.cidr +} + +func (t *tun) Name() string { + return t.Device +} + +func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") +} + +func (t *tun) Read(to []byte) (int, error) { + buf := make([]byte, len(to)+4) + + n, err := t.ReadWriteCloser.Read(buf) + + copy(to, buf[4:]) + return n - 4, err +} + +// Write is only valid for single threaded use +func (t *tun) Write(from []byte) (int, error) { + buf := t.out + if cap(buf) < len(from)+4 { + buf = make([]byte, len(from)+4) + t.out = buf + } + buf = buf[:len(from)+4] + + if len(from) == 0 { + return 0, syscall.EIO + } + + // Determine the IP Family for the NULL L2 Header + ipVer := from[0] >> 4 + if ipVer == 4 { + buf[3] = syscall.AF_INET + } else if ipVer == 6 { + buf[3] = syscall.AF_INET6 + } else { + return 0, fmt.Errorf("unable to determine IP version from packet") + } + + copy(buf[4:], from) + + n, err := 
t.ReadWriteCloser.Write(buf) + return n - 4, err +} diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index a4ee20ba7..383398322 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -8,9 +8,11 @@ import ( "io" "net" "os" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" ) @@ -18,21 +20,26 @@ type TestTun struct { Device string cidr *net.IPNet Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger + closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, cidr, true) + if err != nil { + return nil, err + } routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err } return &TestTun{ - Device: deviceName, + Device: c.GetString("tun.dev", ""), cidr: cidr, Routes: routes, routeTree: routeTree, @@ -42,7 +49,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes }, nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -50,8 +57,12 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int // These are unencrypted ip layer frames destined for another nebula node. // packets should exit the udp side, capture them with udpConn.Get func (t *TestTun) Send(packet []byte) { - if t.l.Level >= logrus.InfoLevel { - t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet") + if t.closed.Load() { + return + } + + if t.l.Level >= logrus.DebugLevel { + t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") } t.rxPackets <- packet } @@ -77,12 +88,8 @@ func (t *TestTun) Get(block bool) []byte { //********************************************************************************************************************// func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *TestTun) Activate() error { @@ -98,6 +105,10 @@ func (t *TestTun) Name() string { } func (t *TestTun) Write(b []byte) (n int, err error) { + if t.closed.Load() { + return 0, io.ErrClosedPipe + } + packet := make([]byte, len(b), len(b)) copy(packet, b) t.TxPackets <- packet @@ -105,7 +116,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) { } func (t *TestTun) Close() error { - close(t.rxPackets) + if t.closed.CompareAndSwap(false, true) { + close(t.rxPackets) + close(t.TxPackets) + } return nil } diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index 8e2e571bd..a1acd2b25 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -6,10 +6,13 @@ import ( "net" "os/exec" "strconv" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" "github.com/songgao/water" ) @@ -17,25 +20,34 @@ type 
waterTun struct { Device string cidr *net.IPNet MTU int - Routes []Route - routeTree *cidr.Tree4 - + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + l *logrus.Logger + f *net.Interface *water.Interface } -func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) { - routeTree, err := makeRouteTree(l, routes, false) +func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { + // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() + t := &waterTun{ + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err := t.reload(c, true) if err != nil { return nil, err } - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - return &waterTun{ - cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, - }, nil + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *waterTun) Activate() error { @@ -74,35 +86,105 @@ func (t *waterTun) Activate() error { return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) } - iface, err := net.InterfaceByName(t.Device) + t.f, err = net.InterfaceByName(t.Device) if err != nil { return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) } - for _, r := range t.Routes { - if r.Via == nil { + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *waterTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to set routes", err, t.l) + } else { + for _, r := range findRemovedRoutes(routes, *oldRoutes) { + t.l.WithField("route", r).Info("Removed route") + } + } + } + + return nil +} + +func (t *waterTun) addRoutes(logErrors bool) error { + // Path routes + routes := *t.Routes.Load() + for _, r := range routes { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - err = exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric), + err := exec.Command( + "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), ).Run() + if err != nil { - return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err) + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") } } return nil } -func (t *waterTun) 
RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) +func (t *waterTun) removeRoutes(routes []Route) { + for _, r := range routes { + if !r.Install { + continue + } + + err := exec.Command( + "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), + ).Run() + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } } +} - return 0 +func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *waterTun) Cidr() *net.IPNet { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index e35e98bbd..f85ee9cee 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -12,13 +12,14 @@ import ( "syscall" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") @@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int } if useWintun { - device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes) + device, err := newWinTun(c, l, cidr, multiqueue) if err != nil { return nil, fmt.Errorf("create Wintun interface failed, %w", err) } return device, nil } - device, err := newWaterTun(l, cidr, defaultMTU, routes) + device, err := newWaterTun(c, l, cidr, multiqueue) if err != nil { return nil, fmt.Errorf("create wintap driver failed, %w", err) } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 0538849be..197e3a717 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -6,11 +6,14 @@ import ( "io" "net" "net/netip" + "sync/atomic" "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -23,8 +26,9 @@ type winTun struct { cidr *net.IPNet prefix netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4 + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + l *logrus.Logger tun *wintun.NativeTun } @@ -48,76 +52,148 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { + deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } - tunDevice, err := 
wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) + prefix, err := iputil.ToNetIpPrefix(*cidr) if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) + return nil, err } - routeTree, err := makeRouteTree(l, routes, false) + t := &winTun{ + Device: deviceName, + cidr: cidr, + prefix: prefix, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) if err != nil { return nil, err } - prefix, err := iputil.ToNetIpPrefix(*cidr) + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { - return nil, err + // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + return nil, fmt.Errorf("create TUN device failed: %w", err) + } } + t.tun = tunDevice.(*wintun.NativeTun) - return &winTun{ - Device: deviceName, - cidr: cidr, - prefix: prefix, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) - tun: tunDevice.(*wintun.NativeTun), - }, nil + return t, nil +} + +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil } func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil { + err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) + if err != nil { return fmt.Errorf("failed to set address: %w", err) } + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() foundDefault4 := false - routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) - for _, r := range t.Routes { - if r.Via == nil { + for _, r := range routes { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { - foundDefault4 = true - } - } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) if err != nil { - return err + retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } } // Add our unsafe route - routes = append(routes, 
&winipcfg.RouteData{ - Destination: prefix, - NextHop: r.Via.ToNetIpAddr(), - Metric: uint32(r.Metric), - }) - } + err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") + } - if err := luid.AddRoutes(routes); err != nil { - return fmt.Errorf("failed to add routes: %w", err) + if !foundDefault4 { + if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { + foundDefault4 = true + } + } } ipif, err := luid.IPInterface(windows.AF_INET) @@ -134,17 +210,36 @@ func (t *winTun) Activate() error { if err := ipif.Set(); err != nil { return fmt.Errorf("failed to set ip interface: %w", err) } - return nil } -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + + for _, r := range routes { + if !r.Install { + continue + } + + prefix, err := iputil.ToNetIpPrefix(*r.Cidr) + if err != nil { + t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix") + continue + } + + err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr()) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } } + return nil +} - return 0 +func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *winTun) Cidr() *net.IPNet { diff --git a/overlay/user.go b/overlay/user.go new file mode 100644 index 000000000..9d819ae99 --- /dev/null +++ b/overlay/user.go @@ -0,0 +1,63 @@ +package overlay + +import ( + "io" + "net" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" +) + +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return NewUserDevice(tunCidr) +} + +func NewUserDevice(tunCidr *net.IPNet) (Device, error) { + // these pipes guarantee each write/read will match 1:1 + or, ow := io.Pipe() + ir, iw := io.Pipe() + return &UserDevice{ + tunCidr: tunCidr, + outboundReader: or, + outboundWriter: ow, + inboundReader: ir, + inboundWriter: iw, + }, nil +} + +type UserDevice struct { + tunCidr *net.IPNet + + outboundReader *io.PipeReader + outboundWriter *io.PipeWriter + + inboundReader *io.PipeReader + inboundWriter *io.PipeWriter +} + +func (d *UserDevice) Activate() error { + return nil +} +func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return d, nil +} + +func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { + return d.inboundReader, d.outboundWriter +} + +func (d *UserDevice) Read(p []byte) (n int, err error) { + return d.outboundReader.Read(p) +} +func (d *UserDevice) Write(p []byte) (n int, err error) { + return d.inboundWriter.Write(p) +} +func (d *UserDevice) Close() error { + d.inboundWriter.Close() + d.outboundWriter.Close() + return nil +} diff --git a/pki.go b/pki.go new file mode 100644 index 000000000..91478ce51 --- /dev/null +++ b/pki.go 
@@ -0,0 +1,248 @@ +package nebula + +import ( + "errors" + "fmt" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" +) + +type PKI struct { + cs atomic.Pointer[CertState] + caPool atomic.Pointer[cert.NebulaCAPool] + l *logrus.Logger +} + +type CertState struct { + Certificate *cert.NebulaCertificate + RawCertificate []byte + RawCertificateNoKey []byte + PublicKey []byte + PrivateKey []byte +} + +func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { + pki := &PKI{l: l} + err := pki.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + rErr := pki.reload(c, false) + if rErr != nil { + util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) + } + }) + + return pki, nil +} + +func (p *PKI) GetCertState() *CertState { + return p.cs.Load() +} + +func (p *PKI) GetCAPool() *cert.NebulaCAPool { + return p.caPool.Load() +} + +func (p *PKI) reload(c *config.C, initial bool) error { + err := p.reloadCert(c, initial) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + err = p.reloadCAPool(c) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + return nil +} + +func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { + cs, err := newCertStateFromConfig(c) + if err != nil { + return util.NewContextualError("Could not load client cert", nil, err) + } + + if !initial { + // did IP in cert change? if so, don't set + currentCert := p.cs.Load().Certificate + oldIPs := currentCert.Details.Ips + newIPs := cs.Certificate.Details.Ips + if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + return util.NewContextualError( + "IP in new cert was different from old", + m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + nil, + ) + } + } + + p.cs.Store(cs) + if initial { + p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + } else { + p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + } + return nil +} + +func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { + caPool, err := loadCAPoolFromConfig(p.l, c) + if err != nil { + return util.NewContextualError("Failed to load ca from config", nil, err) + } + + p.caPool.Store(caPool) + p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + return nil +} + +func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { + // Marshal the certificate to ensure it is valid + rawCertificate, err := certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) + } + + publicKey := certificate.Details.PublicKey + cs := &CertState{ + RawCertificate: rawCertificate, + Certificate: certificate, + PrivateKey: privateKey, + PublicKey: publicKey, + } + + cs.Certificate.Details.PublicKey = nil + rawCertNoKey, err := cs.Certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + } + cs.RawCertificateNoKey = rawCertNoKey + // put public key back + cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil +} + +func newCertStateFromConfig(c *config.C) (*CertState, error) { + var pemPrivateKey []byte + var err error + + privPathOrPEM := c.GetString("pki.key", "") + if privPathOrPEM == "" { + return nil, errors.New("no pki.key path or PEM data 
provided") + } + + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + } + + rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + + var rawCert []byte + + pubPathOrPEM := c.GetString("pki.cert", "") + if pubPathOrPEM == "" { + return nil, errors.New("no pki.cert path or PEM data provided") + } + + if strings.Contains(pubPathOrPEM, "-----BEGIN") { + rawCert = []byte(pubPathOrPEM) + pubPathOrPEM = "" + + } else { + rawCert, err = os.ReadFile(pubPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) + } + } + + nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + } + + if nebulaCert.Expired(time.Now()) { + return nil, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(nebulaCert.Details.Ips) == 0 { + return nil, fmt.Errorf("no IPs encoded in certificate") + } + + if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + + return newCertState(nebulaCert, rawKey) +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { + var rawCA []byte + var err error + + caPathOrPEM := c.GetString("pki.ca", "") + if caPathOrPEM == "" { + return nil, errors.New("no pki.ca path or PEM data provided") + } + + if strings.Contains(caPathOrPEM, "-----BEGIN") { + rawCA = []byte(caPathOrPEM) + + } else { + rawCA, err = os.ReadFile(caPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) + } + } + + caPool, err := cert.NewCAPoolFromBytes(rawCA) + if errors.Is(err, cert.ErrExpired) { + var expired int + for _, crt := range caPool.CAs { + if crt.Expired(time.Now()) { + expired++ + l.WithField("cert", crt).Warn("expired certificate present in CA pool") + } + } + + if expired >= len(caPool.CAs) { + return nil, errors.New("no valid CA certificates present") + } + + } else if err != nil { + return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) + } + + for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { + l.WithField("fingerprint", fp).Info("Blocklisting cert") + caPool.BlocklistFingerprint(fp) + } + + return caPool, nil +} diff --git a/punchy.go b/punchy.go index 1ecf7c511..2034405a7 100644 --- a/punchy.go +++ b/punchy.go @@ -9,10 +9,12 @@ import ( ) type Punchy struct { - punch atomic.Bool - respond atomic.Bool - delay atomic.Int64 - l *logrus.Logger + punch atomic.Bool + respond atomic.Bool + delay atomic.Int64 + respondDelay atomic.Int64 + punchEverything atomic.Bool + l *logrus.Logger } func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { @@ -37,6 +39,12 @@ func (p *Punchy) reload(c *config.C, initial bool) { } p.punch.Store(yes) + if yes { + p.l.Info("punchy enabled") + } else { + p.l.Info("punchy disabled") + } + } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here p.l.Warn("Changing 
punchy.punch with reload is not supported, ignoring.") @@ -65,6 +73,20 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.l.Infof("punchy.delay changed to %s", p.GetDelay()) } } + + if initial || c.HasChanged("punchy.target_all_remotes") { + p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) + if !initial { + p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + } + } + + if initial || c.HasChanged("punchy.respond_delay") { + p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) + if !initial { + p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) + } + } } func (p *Punchy) GetPunch() bool { @@ -78,3 +100,11 @@ func (p *Punchy) GetRespond() bool { func (p *Punchy) GetDelay() time.Duration { return (time.Duration)(p.delay.Load()) } + +func (p *Punchy) GetRespondDelay() time.Duration { + return (time.Duration)(p.respondDelay.Load()) +} + +func (p *Punchy) GetTargetEverything() bool { + return p.punchEverything.Load() +} diff --git a/punchy_test.go b/punchy_test.go index 0aa9b6234..bedd2b266 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -18,6 +18,7 @@ func TestNewPunchyFromConfig(t *testing.T) { assert.Equal(t, false, p.GetPunch()) assert.Equal(t, false, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) + assert.Equal(t, 5*time.Second, p.GetRespondDelay()) // punchy deprecation c.Settings["punchy"] = true @@ -44,6 +45,11 @@ func TestNewPunchyFromConfig(t *testing.T) { c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetDelay()) + + // punchy.respond_delay + c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"} + p = NewPunchyFromConfig(l, c) + assert.Equal(t, time.Minute, p.GetRespondDelay()) } func TestPunchy_reload(t *testing.T) { diff --git a/relay_manager.go b/relay_manager.go index 95807bd25..7aa06ccb4 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -61,6 +61,11 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput _, inRelays := hm.Relays[index] if !inRelays { + // Avoid standing up a relay that can't be used since only the primary hostinfo + // will be pointed to by the relay logic + //TODO: if there was an existing primary and it had relay state, should we merge? + hm.unlockedMakePrimary(relayHostInfo) + hm.Relays[index] = relayHostInfo newRelay := Relay{ Type: relayType, @@ -83,17 +88,14 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput // EstablishRelay updates a Requested Relay to become an Established Relay, which can pass traffic. 
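In the EstablishRelay change that follows, the old sequence of querying a relay by index and then mutating its fields inline is replaced by a single CompleteRelayByIdx call on the host's relay state. A rough sketch of what such a guarded "complete" operation can look like; the type and field names here are assumptions for illustration, not Nebula's actual RelayState:

package main

import (
	"fmt"
	"sync"
)

type relayInfo struct {
	LocalIndex  uint32
	RemoteIndex uint32
	Established bool
}

type relayState struct {
	mu      sync.Mutex
	byIndex map[uint32]*relayInfo
}

// completeRelayByIdx records the peer's index and marks the relay established
// in one critical section, so callers never observe a half-updated entry.
func (rs *relayState) completeRelayByIdx(localIdx, remoteIdx uint32) (*relayInfo, bool) {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	r, ok := rs.byIndex[localIdx]
	if !ok {
		return nil, false
	}
	r.RemoteIndex = remoteIdx
	r.Established = true
	return r, true
}

func main() {
	rs := &relayState{byIndex: map[uint32]*relayInfo{7: {LocalIndex: 7}}}
	if r, ok := rs.completeRelayByIdx(7, 42); ok {
		fmt.Printf("relay %d established, remote index %d\n", r.LocalIndex, r.RemoteIndex)
	}
}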
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { - relay, ok := relayHostInfo.relayState.QueryRelayForByIdx(m.InitiatorRelayIndex) + relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relayHostInfo": relayHostInfo.vpnIp, + rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, "initiatorRelayIndex": m.InitiatorRelayIndex, "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager EstablishRelay relayForByIdx not found") + "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } - // relay deserves some synchronization - relay.RemoteIndex = m.ResponderRelayIndex - relay.State = Established return relay, nil } @@ -111,17 +113,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTarget": iputil.VpnIp(m.RelayToIp), - "initiatorIdx": m.InitiatorRelayIndex, - "responderIdx": m.ResponderRelayIndex, - "hostInfo": h.vpnIp}). + "relayFrom": iputil.VpnIp(m.RelayFromIp), + "relayTo": iputil.VpnIp(m.RelayToIp), + "initiatorRelayIndex": m.InitiatorRelayIndex, + "responderRelayIndex": m.ResponderRelayIndex, + "vpnIp": h.vpnIp}). Info("handleCreateRelayResponse") target := iputil.VpnIp(m.RelayToIp) relay, err := rm.EstablishRelay(h, m) if err != nil { - rm.l.WithError(err).WithField("target", target.String()).Error("Failed to update relay for target") + rm.l.WithError(err).Error("Failed to update relay for relayTo") return } // Do I need to complete the relays now? @@ -129,65 +131,92 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp) - if err != nil { - rm.l.WithError(err).WithField("relayPeerIp", relay.PeerIp).Error("Can't find a HostInfo for peer IP") + peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + if peerHostInfo == nil { + rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) if !ok { - rm.l.WithField("peerIp", peerHostInfo.vpnIp).WithField("target", target.String()).Error("peerRelay does not have Relay state for target IP", peerHostInfo.vpnIp.String(), target.String()) + rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } - peerRelay.State = Established - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: peerRelay.LocalIndex, - InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), - RelayToIp: uint32(target), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. 
- WithError(err).Error("relayManager Failed to marhsal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToVpnIp(header.Control, 0, peerHostInfo.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) + if peerRelay.State == PeerRequested { + peerRelay.State = Established + resp := NebulaControl{ + Type: NebulaControl_CreateRelayResponse, + ResponderRelayIndex: peerRelay.LocalIndex, + InitiatorRelayIndex: peerRelay.RemoteIndex, + RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayToIp: uint32(target), + } + msg, err := resp.Marshal() + if err != nil { + rm.l. + WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(resp.RelayFromIp), + "relayTo": iputil.VpnIp(resp.RelayToIp), + "initiatorRelayIndex": resp.InitiatorRelayIndex, + "responderRelayIndex": resp.ResponderRelayIndex, + "vpnIp": peerHostInfo.vpnIp}). + Info("send CreateRelayResponse") + } } } func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTarget": iputil.VpnIp(m.RelayToIp), - "initiatorIdx": m.InitiatorRelayIndex, - "hostInfo": h.vpnIp}). - Info("handleCreateRelayRequest") + from := iputil.VpnIp(m.RelayFromIp) target := iputil.VpnIp(m.RelayToIp) + + logMsg := rm.l.WithFields(logrus.Fields{ + "relayFrom": from, + "relayTo": target, + "initiatorRelayIndex": m.InitiatorRelayIndex, + "vpnIp": h.vpnIp}) + + logMsg.Info("handleCreateRelayRequest") + // Is the source of the relay me? This should never happen, but did happen due to + // an issue migrating relays over to newly re-handshaked host info objects. + if from == f.myVpnIp { + logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + return + } // Is the target of the relay me? if target == f.myVpnIp { existingRelay, ok := h.relayState.QueryRelayForByIp(from) - addRelay := !ok if ok { - // Clean up existing relay, if this is a new request. - if existingRelay.RemoteIndex != m.InitiatorRelayIndex { - // We got a brand new Relay request, because its index is different than what we saw before. - // Clean up the existing Relay state, and get ready to record new Relay state. - rm.hostmap.RemoveRelay(existingRelay.LocalIndex) - addRelay = true + switch existingRelay.State { + case Requested: + ok = h.relayState.CompleteRelayByIP(from, m.InitiatorRelayIndex) + if !ok { + logMsg.Error("Relay State not found") + return + } + case Established: + if existingRelay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } } - } - if addRelay { + } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { + logMsg.WithError(err).Error("Failed to add relay") return } } relay, ok := h.relayState.QueryRelayForByIp(from) - if ok && m.InitiatorRelayIndex != relay.RemoteIndex { - // Do something, Something happened. 
+ if !ok { + logMsg.Error("Relay State not found") + return } resp := NebulaControl{ @@ -199,22 +228,29 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } msg, err := resp.Marshal() if err != nil { - rm.l. + logMsg. WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(resp.RelayFromIp), + "relayTo": iputil.VpnIp(resp.RelayToIp), + "initiatorRelayIndex": resp.InitiatorRelayIndex, + "responderRelayIndex": resp.ResponderRelayIndex, + "vpnIp": h.vpnIp}). + Info("send CreateRelayResponse") } return } else { // the target is not me. Create a relay to the target, from me. - if rm.GetAmRelay() == false { + if !rm.GetAmRelay() { return } - peer, err := rm.hostmap.QueryVpnIp(target) - if err != nil { + peer := rm.hostmap.QueryVpnIp(target) + if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! - f.getOrHandshake(target) + f.Handshake(target) return } if peer.remote == nil { @@ -223,6 +259,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } sendCreateRequest := false var index uint32 + var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex @@ -247,40 +284,43 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } msg, err := req.Marshal() if err != nil { - rm.l. + logMsg. WithError(err).Error("relayManager Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, target, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(req.RelayFromIp), + "relayTo": iputil.VpnIp(req.RelayToIp), + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnIp": target}). + Info("send CreateRelayRequest") } } // Also track the half-created Relay state just received relay, ok := h.relayState.QueryRelayForByIp(target) if !ok { // Add the relay - state := Requested + state := PeerRequested if targetRelay != nil && targetRelay.State == Established { state = Established } _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state) if err != nil { - rm.l. + logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } } else { - if relay.RemoteIndex != m.InitiatorRelayIndex { - // This is a stale Relay entry for the same tunnel targets. - // Clean up the existing stuff. - rm.RemoveRelay(relay.LocalIndex) - // Add the new relay - _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, Requested) - if err != nil { - return - } - relay, _ = h.relayState.QueryRelayForByIp(target) - } switch relay.State { case Established: + if relay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. 
+ logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, @@ -293,7 +333,14 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N rm.l. WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(resp.RelayFromIp), + "relayTo": iputil.VpnIp(resp.RelayToIp), + "initiatorRelayIndex": resp.InitiatorRelayIndex, + "responderRelayIndex": resp.ResponderRelayIndex, + "vpnIp": h.vpnIp}). + Info("send CreateRelayResponse") } case Requested: diff --git a/remote_list.go b/remote_list.go index 4b544f68f..60a1afdaf 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,10 +2,16 @@ package nebula import ( "bytes" + "context" "net" + "net/netip" "sort" + "strconv" "sync" + "sync/atomic" + "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) @@ -55,6 +61,131 @@ type cacheV6 struct { reported []*Ip6AndPort } +type hostnamePort struct { + name string + port uint16 +} + +type hostnamesResults struct { + hostnames []hostnamePort + network string + lookupTimeout time.Duration + cancelFn func() + l *logrus.Logger + ips atomic.Pointer[map[netip.AddrPort]struct{}] +} + +func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { + r := &hostnamesResults{ + hostnames: make([]hostnamePort, len(hostPorts)), + network: network, + lookupTimeout: timeout, + l: l, + } + + // Fastrack IP addresses to ensure they're immediately available for use. + // DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine. 
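NewHostnameResults, continued below, separates static_host_map entries that are already literal IP:port pairs (usable immediately) from hostnames that still need a background DNS lookup. A self-contained illustration of that check, with stand-in names and example entries:

package main

import (
	"fmt"
	"net"
	"net/netip"
	"strconv"
)

// classify reports whether a host:port entry is a literal IP address (usable
// right away) or a hostname that will need a DNS lookup later.
func classify(hostPort string) (netip.AddrPort, bool, error) {
	host, portStr, err := net.SplitHostPort(hostPort)
	if err != nil {
		return netip.AddrPort{}, false, err
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return netip.AddrPort{}, false, err
	}
	addr, err := netip.ParseAddr(host)
	if err != nil {
		// Not a literal IP address; treat it as a hostname to resolve later.
		return netip.AddrPort{}, false, nil
	}
	return netip.AddrPortFrom(addr, uint16(port)), true, nil
}

func main() {
	for _, entry := range []string{"192.0.2.10:4242", "lighthouse.example.com:4242"} {
		ap, isIP, err := classify(entry)
		fmt.Println(entry, "literal:", isIP, "addr:", ap, "err:", err)
	}
}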
+ performBackgroundLookup := false + ips := map[netip.AddrPort]struct{}{} + for idx, hostPort := range hostPorts { + + rIp, sPort, err := net.SplitHostPort(hostPort) + if err != nil { + return nil, err + } + + iPort, err := strconv.Atoi(sPort) + if err != nil { + return nil, err + } + + r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} + addr, err := netip.ParseAddr(rIp) + if err != nil { + // This address is a hostname, not an IP address + performBackgroundLookup = true + continue + } + + // Save the IP address immediately + ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{} + } + r.ips.Store(&ips) + + // Time for the DNS lookup goroutine + if performBackgroundLookup { + newCtx, cancel := context.WithCancel(ctx) + r.cancelFn = cancel + ticker := time.NewTicker(d) + go func() { + defer ticker.Stop() + for { + netipAddrs := map[netip.AddrPort]struct{}{} + for _, hostPort := range r.hostnames { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout) + addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) + timeoutCancel() + if err != nil { + l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + continue + } + for _, a := range addrs { + netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + } + } + origSet := r.ips.Load() + different := false + for a := range *origSet { + if _, ok := netipAddrs[a]; !ok { + different = true + break + } + } + if !different { + for a := range netipAddrs { + if _, ok := (*origSet)[a]; !ok { + different = true + break + } + } + } + if different { + l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + r.ips.Store(&netipAddrs) + onUpdate() + } + select { + case <-newCtx.Done(): + return + case <-ticker.C: + continue + } + } + }() + } + + return r, nil +} + +func (hr *hostnamesResults) Cancel() { + if hr != nil && hr.cancelFn != nil { + hr.cancelFn() + } +} + +func (hr *hostnamesResults) GetIPs() []netip.AddrPort { + var retSlice []netip.AddrPort + if hr != nil { + p := hr.ips.Load() + if p != nil { + for k := range *p { + retSlice = append(retSlice, k) + } + } + } + return retSlice +} + // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. // It serves as a local cache of query replies, host update notifications, and locally learned addresses type RemoteList struct { @@ -72,6 +203,9 @@ type RemoteList struct { // For learned addresses, this is the vpnIp that sent the packet cache map[iputil.VpnIp]*cache + hr *hostnamesResults + shouldAdd func(netip.Addr) bool + // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. 
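The background goroutine above re-resolves each hostname on a ticker, applies a per-lookup timeout, and only swaps the stored address set (and notifies the caller) when the results actually change. A stripped-down, standalone version of that loop; the hostname, intervals, and "ip4" network are illustrative choices:

package main

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"time"
)

// resolveSet resolves one hostname with a per-lookup timeout and returns the
// results as a set of address:port pairs.
func resolveSet(ctx context.Context, host string, port uint16, timeout time.Duration) (map[netip.AddrPort]struct{}, error) {
	lookupCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	addrs, err := net.DefaultResolver.LookupNetIP(lookupCtx, "ip4", host)
	if err != nil {
		return nil, err
	}
	set := map[netip.AddrPort]struct{}{}
	for _, a := range addrs {
		set[netip.AddrPortFrom(a, port)] = struct{}{}
	}
	return set, nil
}

func equalSets(a, b map[netip.AddrPort]struct{}) bool {
	if len(a) != len(b) {
		return false
	}
	for k := range a {
		if _, ok := b[k]; !ok {
			return false
		}
	}
	return true
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	ticker := time.NewTicker(3 * time.Second)
	defer ticker.Stop()

	var last map[netip.AddrPort]struct{}
	for {
		cur, err := resolveSet(ctx, "lighthouse.example.com", 4242, time.Second)
		if err == nil && !equalSets(last, cur) {
			fmt.Println("DNS results changed:", cur)
			last = cur
		}
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}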
// They should not be tried again during a handshake badRemotes []*udp.Addr @@ -81,14 +215,21 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList() *RemoteList { +func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]*udp.Addr, 0), + relays: make([]*iputil.VpnIp, 0), + cache: make(map[iputil.VpnIp]*cache), + shouldAdd: shouldAdd, } } +func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { + // Cancel any existing hostnamesResults DNS goroutine to release resources + r.hr.Cancel() + r.hr = hr +} + // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { @@ -437,6 +578,17 @@ func (r *RemoteList) unlockedCollect() { } } + dnsAddrs := r.hr.GetIPs() + for _, addr := range dnsAddrs { + if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { + v6 := addr.Addr().As16() + addrs = append(addrs, &udp.Addr{ + IP: v6[:], + Port: addr.Port(), + }) + } + } + r.addrs = addrs r.relays = relays diff --git a/remote_list_test.go b/remote_list_test.go index 21709301b..49aa17191 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,7 +9,7 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 000000000..6d5c8a433 --- /dev/null +++ b/service/listener.go @@ -0,0 +1,36 @@ +package service + +import ( + "io" + "net" +) + +type tcpListener struct { + port uint16 + s *Service + addr *net.TCPAddr + accept chan net.Conn +} + +func (l *tcpListener) Accept() (net.Conn, error) { + conn, ok := <-l.accept + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (l *tcpListener) Close() error { + l.s.mu.Lock() + defer l.s.mu.Unlock() + delete(l.s.mu.listeners, uint16(l.addr.Port)) + + close(l.accept) + + return nil +} + +// Addr returns the listener's network address. 
+func (l *tcpListener) Addr() net.Addr { + return l.addr +} diff --git a/service/service.go b/service/service.go new file mode 100644 index 000000000..6816be673 --- /dev/null +++ b/service/service.go @@ -0,0 +1,248 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "math" + "net" + "os" + "strings" + "sync" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" + "golang.org/x/sync/errgroup" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const nicID = 1 + +type Service struct { + eg *errgroup.Group + control *nebula.Control + ipstack *stack.Stack + + mu struct { + sync.Mutex + + listeners map[uint16]*tcpListener + } +} + +func New(config *config.C) (*Service, error) { + logger := logrus.New() + logger.Out = os.Stdout + + control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) + if err != nil { + return nil, err + } + control.Start() + + ctx := control.Context() + eg, ctx := errgroup.WithContext(ctx) + s := Service{ + eg: eg, + control: control, + } + s.mu.listeners = map[uint16]*tcpListener{} + + device, ok := control.Device().(*overlay.UserDevice) + if !ok { + return nil, errors.New("must be using user device") + } + + s.ipstack = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, + }) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "") + if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { + return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) + } + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4))) + s.ipstack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + }) + + ipNet := device.Cidr() + pa := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), + Protocol: ipv4.ProtocolNumber, + } + if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ + PEB: stack.CanBePrimaryEndpoint, // zero value default + ConfigType: stack.AddressConfigStatic, // zero value default + }); err != nil { + return nil, fmt.Errorf("error creating IP: %s", err) + } + + const tcpReceiveBufferSize = 0 + const maxInFlightConnectionAttempts = 1024 + tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) + s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + reader, writer := device.Pipe() + + go func() { + <-ctx.Done() + reader.Close() + writer.Close() + }() + + // create Goroutines to 
forward packets between Nebula and Gvisor + eg.Go(func() error { + buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) + for { + // this will read exactly one packet + n, err := reader.Read(buf) + if err != nil { + return err + } + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), + }) + linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + + if err := ctx.Err(); err != nil { + return err + } + } + }) + eg.Go(func() error { + for { + packet := linkEP.ReadContext(ctx) + if packet == nil { + if err := ctx.Err(); err != nil { + return err + } + continue + } + bufView := packet.ToView() + if _, err := bufView.WriteTo(writer); err != nil { + return err + } + bufView.Release() + } + }) + + return &s, nil +} + +// DialContext dials the provided address. Currently only TCP is supported. +func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +} + +// Listen listens on the provided address. Currently only TCP with wildcard +// addresses are supported. +func (s *Service) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) { + return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP) + } + if addr.Port == 0 { + return nil, errors.New("specific port required, got 0") + } + if addr.Port < 0 || addr.Port >= math.MaxUint16 { + return nil, fmt.Errorf("invalid port %d", addr.Port) + } + port := uint16(addr.Port) + + l := &tcpListener{ + port: port, + s: s, + addr: addr, + accept: make(chan net.Conn), + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.mu.listeners[port]; ok { + return nil, fmt.Errorf("already listening on port %d", port) + } + s.mu.listeners[port] = l + + return l, nil +} + +func (s *Service) Wait() error { + return s.eg.Wait() +} + +func (s *Service) Close() error { + s.control.Stop() + return nil +} + +func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { + endpointID := r.ID() + + s.mu.Lock() + defer s.mu.Unlock() + + l, ok := s.mu.listeners[endpointID.LocalPort] + if !ok { + r.Complete(true) + return + } + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Printf("got error creating endpoint %q", err) + r.Complete(true) + return + } + r.Complete(false) + ep.SocketOptions().SetKeepAlive(true) + + conn := gonet.NewTCPConn(&wq, ep) + l.accept <- conn +} diff --git a/service/service_test.go b/service/service_test.go new file mode 100644 index 000000000..d1909cd15 --- /dev/null +++ b/service/service_test.go @@ -0,0 +1,165 @@ +package service + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + "time" + + "dario.cat/mergo" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "golang.org/x/sync/errgroup" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +func 
newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { + + vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} + copy(vpnIpNet.IP, udpIp) + + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + s, err := New(&c) + if err != nil { + panic(err) + } + return s +} + +func TestService(t *testing.T) { + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": 4243, + }, + }) + b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + "static_host_map": m{ + "10.0.0.1": []string{"localhost:4243"}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + + ln, err := a.Listen("tcp", ":1234") + if err != nil { + t.Fatal(err) + } + var eg errgroup.Group + eg.Go(func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + defer conn.Close() + + t.Log("accepted connection") + + if _, err := conn.Write([]byte("server msg")); err != nil { + return err + } + + t.Log("server: wrote message") + + data := make([]byte, 100) + n, err := conn.Read(data) + if err != nil { + return err + } + data = data[:n] + if !bytes.Equal(data, []byte("client msg")) { + return errors.New("got invalid message from client") + } + t.Log("server: read message") + return conn.Close() + }) + + c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234") + if err != nil { + t.Fatal(err) + } + if _, err := c.Write([]byte("client msg")); err != nil { + t.Fatal(err) + } + + data := make([]byte, 100) + n, err := c.Read(data) + if err != nil { + t.Fatal(err) + } + data = data[:n] + if !bytes.Equal(data, []byte("server msg")) { + t.Fatal("got invalid message from client") + } + + if err := c.Close(); err != nil { + t.Fatal(err) + } + + if err := eg.Wait(); err != nil { + t.Fatal(err) + } +} diff --git a/ssh.go b/ssh.go index ffb1efc5a..9000d3d71 100644 --- a/ssh.go +++ b/ssh.go @@ -3,14 +3,16 @@ package nebula import ( "bytes" "encoding/json" + "errors" "flag" "fmt" - "io/ioutil" "net" "os" "reflect" + "runtime" "runtime/pprof" "sort" + "strconv" "strings" "github.com/sirupsen/logrus" @@ -22,8 +24,9 @@ import ( ) type sshListHostMapFlags struct { - Json bool - Pretty bool + Json bool + Pretty bool + ByIndex bool } type sshPrintCertFlags struct { @@ -92,14 +95,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro } 
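The service package and its test above show the intended embedding flow: load a full Nebula config, start the userspace stack, then listen and dial across the overlay with no TUN device. A condensed usage sketch; the config placeholder, port, and peer address are illustrative, not a working setup:

package main

import (
	"context"
	"log"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/service"
)

func main() {
	// A complete Nebula config (pki, lighthouse, firewall, ...) is required;
	// the string below is a placeholder, not a working configuration.
	var c config.C
	if err := c.LoadString("<full nebula yaml config here>"); err != nil {
		log.Fatal(err)
	}

	s, err := service.New(&c)
	if err != nil {
		log.Fatal(err)
	}
	defer s.Close()

	// Accept TCP connections addressed to this node's Nebula IP on port 1234,
	// entirely in userspace.
	ln, err := s.Listen("tcp", ":1234")
	if err != nil {
		log.Fatal(err)
	}
	go func() {
		for {
			conn, err := ln.Accept()
			if err != nil {
				return
			}
			conn.Write([]byte("hello over nebula\n"))
			conn.Close()
		}
	}()

	// Dial another node by its Nebula IP.
	conn, err := s.DialContext(context.Background(), "tcp", "10.0.0.1:1234")
	if err != nil {
		log.Fatal(err)
	}
	conn.Close()
}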
//TODO: no good way to reload this right now - hostKeyFile := c.GetString("sshd.host_key", "") - if hostKeyFile == "" { + hostKeyPathOrKey := c.GetString("sshd.host_key", "") + if hostKeyPathOrKey == "" { return nil, fmt.Errorf("sshd.host_key must be provided") } - hostKeyBytes, err := ioutil.ReadFile(hostKeyFile) - if err != nil { - return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err) + var hostKeyBytes []byte + if strings.Contains(hostKeyPathOrKey, "-----BEGIN") { + hostKeyBytes = []byte(hostKeyPathOrKey) + } else { + hostKeyBytes, err = os.ReadFile(hostKeyPathOrKey) + if err != nil { + return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err) + } } err = ssh.SetHostKey(hostKeyBytes) @@ -170,7 +178,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { +func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -179,10 +187,11 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(hostMap, fs, w) + return sshListHostMap(f.hostMap, fs, w) }, }) @@ -194,10 +203,11 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(pendingHostMap, fs, w) + return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -212,7 +222,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListLighthouseMap(lightHouse, fs, w) + return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -226,7 +236,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap ssh.RegisterCommand(&sshd.Command{ Name: "start-cpu-profile", - ShortDescription: "Starts a cpu profile and write output to the provided file", + ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`", Callback: sshStartCpuProfile, }) @@ -241,10 +251,22 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap ssh.RegisterCommand(&sshd.Command{ Name: "save-heap-profile", - ShortDescription: "Saves a heap profile to the provided path", + ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`", Callback: sshGetHeapProfile, }) + ssh.RegisterCommand(&sshd.Command{ + Name: "mutex-profile-fraction", + ShortDescription: "Gets or sets runtime.SetMutexProfileFraction", + Callback: 
sshMutexProfileFraction, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "save-mutex-profile", + ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`", + Callback: sshGetMutexProfile, + }) + ssh.RegisterCommand(&sshd.Command{ Name: "log-level", ShortDescription: "Gets or sets the current log level", @@ -265,7 +287,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap Name: "version", ShortDescription: "Prints the currently running version of nebula", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshVersion(ifce, fs, a, w) + return sshVersion(f, fs, a, w) }, }) @@ -296,7 +318,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintCert(ifce, fs, a, w) + return sshPrintCert(f, fs, a, w) }, }) @@ -310,7 +332,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintTunnel(ifce, fs, a, w) + return sshPrintTunnel(f, fs, a, w) }, }) @@ -324,7 +346,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintRelays(ifce, fs, a, w) + return sshPrintRelays(f, fs, a, w) }, }) @@ -338,7 +360,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshChangeRemote(ifce, fs, a, w) + return sshChangeRemote(f, fs, a, w) }, }) @@ -352,7 +374,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCloseTunnel(ifce, fs, a, w) + return sshCloseTunnel(f, fs, a, w) }, }) @@ -367,7 +389,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCreateTunnel(ifce, fs, a, w) + return sshCreateTunnel(f, fs, a, w) }, }) @@ -376,19 +398,25 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap ShortDescription: "Query the lighthouses for the provided vpn ip", Help: "This command is asynchronous. 
Only currently known udp ips will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshQueryLighthouse(ifce, fs, a, w) + return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { //TODO: error return nil } - hm := listHostMap(hostMap) + var hm []ControlHostInfo + if fs.ByIndex { + hm = listHostMapIndexes(hl) + } else { + hm = listHostMapHosts(hl) + } + sort.Slice(hm, func(i, j int) bool { return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 }) @@ -515,7 +543,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp, ifce) + rl := ifce.lightHouse.Query(vpnIp) if rl != nil { cm = rl.CopyCache() } @@ -543,8 +571,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -585,12 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -603,11 +631,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo) + hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) if addr != nil { hostInfo.SetRemote(addr) } - ifce.getOrHandshake(vpnIp) return w.WriteLine("Created") } @@ -642,8 +669,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -672,6 +699,45 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } +func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + rate := runtime.SetMutexProfileFraction(-1) + return w.WriteLine(fmt.Sprintf("Current value: %d", rate)) + } + + newRate, err := strconv.Atoi(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("Invalid argument: %s", a[0])) + } + + oldRate := runtime.SetMutexProfileFraction(newRate) + return w.WriteLine(fmt.Sprintf("New value: %d. 
Old value: %d", newRate, oldRate)) +} + +func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + return w.WriteLine("No path to write profile provided") + } + + file, err := os.Create(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) + } + defer file.Close() + + mutexProfile := pprof.Lookup("mutex") + if mutexProfile == nil { + return w.WriteLine("Unable to get pprof.Lookup(\"mutex\")") + } + + err = mutexProfile.WriteTo(file, 0) + if err != nil { + return w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err)) + } + + return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) +} + func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) @@ -711,7 +777,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.certState.certificate + cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { parsedIp := net.ParseIP(a[0]) if parsedIp == nil { @@ -723,8 +789,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -809,9 +875,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr for k, v := range relays { ro := RelayOutput{NebulaIp: v.vpnIp} co.Relays = append(co.Relays, &ro) - relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp) - if err != nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err}) + relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + if relayHI == nil { + ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) continue } for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { @@ -847,8 +913,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err == nil { + relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) 
} @@ -883,8 +949,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -893,7 +959,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr enc.SetIndent("", " ") } - return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges)) + return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { diff --git a/stats.go b/stats.go index 03b4d8199..c88c45cc9 100644 --- a/stats.go +++ b/stats.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "runtime" + "strconv" "time" graphite "github.com/cyberdelia/go-metrics-graphite" @@ -105,8 +106,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV Name: "info", Help: "Version information for the Nebula binary", ConstLabels: prometheus.Labels{ - "version": buildVersion, - "goversion": runtime.Version(), + "version": buildVersion, + "goversion": runtime.Version(), + "boringcrypto": strconv.FormatBool(boringEnabled()), }, }) pr.MustRegister(g) diff --git a/test/logger.go b/test/logger.go index 197ab44d2..b5a717d82 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,7 +1,7 @@ package test import ( - "io/ioutil" + "io" "os" "github.com/sirupsen/logrus" @@ -12,7 +12,7 @@ func NewLogger() *logrus.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) return l } diff --git a/udp/conn.go b/udp/conn.go index fa52fe5b3..a2c24a1f1 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,6 +1,7 @@ package udp import ( + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -9,7 +10,6 @@ const MTU = 9001 type EncReader func( addr *Addr, - via interface{}, out []byte, packet []byte, header *header.H, @@ -19,3 +19,33 @@ type EncReader func( q int, localCache firewall.ConntrackCache, ) + +type Conn interface { + Rebind() error + LocalAddr() (*Addr, error) + ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + WriteTo(b []byte, addr *Addr) error + ReloadConfig(c *config.C) + Close() error +} + +type NoopConn struct{} + +func (NoopConn) Rebind() error { + return nil +} +func (NoopConn) LocalAddr() (*Addr, error) { + return nil, nil +} +func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { + return +} +func (NoopConn) WriteTo(_ []byte, _ *Addr) error { + return nil +} +func (NoopConn) ReloadConfig(_ *config.C) { + return +} +func (NoopConn) Close() error { + return nil +} diff --git a/udp/temp.go b/udp/temp.go index 5cc8c1c1a..2efe31d24 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,22 +1,9 @@ package udp import ( - "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" ) -type EncWriter interface { - SendVia(via interface{}, - relay interface{}, - ad, - nb, - out []byte, - nocopy bool, - ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) -} - //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc 
func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) +type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) diff --git a/udp/udp_all.go b/udp/udp_all.go index a4a462e50..093bf69cc 100644 --- a/udp/udp_all.go +++ b/udp/udp_all.go @@ -64,6 +64,22 @@ func (ua *Addr) Copy() *Addr { return &nu } +type AddrSlice []*Addr + +func (a AddrSlice) Equal(b AddrSlice) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equals(b[i]) { + return false + } + } + + return true +} + func ParseIPAndPort(s string) (net.IP, uint16, error) { rIp, sPort, err := net.SplitHostPort(s) if err != nil { diff --git a/udp/udp_android.go b/udp/udp_android.go index d2812a8af..8d6907488 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -8,9 +8,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -34,6 +39,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go new file mode 100644 index 000000000..785aa6a74 --- /dev/null +++ b/udp/udp_bsd.go @@ -0,0 +1,47 @@ +//go:build (openbsd || freebsd) && !e2e_testing +// +build openbsd freebsd +// +build !e2e_testing + +package udp + +// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig + +import ( + "fmt" + "net" + "syscall" + + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + +func NewListenConfig(multi bool) net.ListenConfig { + return net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + var controlErr error + err := c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) + return + } + }) + if err != nil { + return err + } + if controlErr != nil { + return controlErr + } + } + return nil + }, + } +} + +func (u *GenericConn) Rebind() error { + return nil +} diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 69d0c58e7..08e1b6a80 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -10,9 +10,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -37,11 +42,16 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { - file, err := u.File() +func (u *GenericConn) Rebind() error { + rc, err := u.UDPConn.SyscallConn() if err != nil { return err } - return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) + return rc.Control(func(fd uintptr) { + err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) + if err != nil { + u.l.WithError(err).Error("Failed to 
rebind udp socket") + } + }) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 0a7c0d90f..1dd6d1de7 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -18,30 +18,32 @@ import ( "github.com/slackhq/nebula/header" ) -type Conn struct { +type GenericConn struct { *net.UDPConn l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +var _ Conn = &GenericConn{} + +func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) - pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) + pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { - return &Conn{UDPConn: uc, l: l}, nil + return &GenericConn{UDPConn: uc, l: l}, nil } return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (uc *Conn) WriteTo(b []byte, addr *Addr) error { - _, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { + _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) return err } -func (uc *Conn) LocalAddr() (*Addr, error) { - a := uc.UDPConn.LocalAddr() +func (u *GenericConn) LocalAddr() (*Addr, error) { + a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: @@ -55,11 +57,11 @@ func (uc *Conn) LocalAddr() (*Addr, error) { } } -func (u *Conn) ReloadConfig(c *config.C) { +func (u *GenericConn) ReloadConfig(c *config.C) { // TODO } -func NewUDPStatsEmitter(udpConns []*Conn) func() { +func NewUDPStatsEmitter(udpConns []Conn) func() { // No UDP stats for non-linux return func() {} } @@ -68,7 +70,7 @@ type rawMessage struct { Len uint32 } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) buffer := make([]byte, MTU) h := &header.H{} @@ -80,12 +82,12 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 5d4b16a28..1151c8906 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -20,8 +20,9 @@ import ( //TODO: make it support reload as best you can! 
-type Conn struct { +type StdConn struct { sysFd int + isV4 bool l *logrus.Logger batch int } @@ -45,9 +46,22 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +func maybeIPV4(ip net.IP) (net.IP, bool) { + ip4 := ip.To4() + if ip4 != nil { + return ip4, true + } + return ip, false +} + +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + ipV4, isV4 := maybeIPV4(ip) + af := unix.AF_INET6 + if isV4 { + af = unix.AF_INET + } syscall.ForkLock.RLock() - fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { unix.CloseOnExec(fd) } @@ -58,9 +72,6 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) ( return nil, fmt.Errorf("unable to open socket: %s", err) } - var lip [16]byte - copy(lip[:], net.ParseIP(ip)) - if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) @@ -68,7 +79,17 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) ( } //TODO: support multiple listening IPs (for limiting ipv6) - if err = unix.Bind(fd, &unix.SockaddrInet6{Addr: lip, Port: port}); err != nil { + var sa unix.Sockaddr + if isV4 { + sa4 := &unix.SockaddrInet4{Port: port} + copy(sa4.Addr[:], ipV4) + sa = sa4 + } else { + sa6 := &unix.SockaddrInet6{Port: port} + copy(sa6.Addr[:], ip.To16()) + sa = sa6 + } + if err = unix.Bind(fd, sa); err != nil { return nil, fmt.Errorf("unable to bind to socket: %s", err) } @@ -77,30 +98,30 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &Conn{sysFd: fd, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err } -func (u *Conn) Rebind() error { +func (u *StdConn) Rebind() error { return nil } -func (u *Conn) SetRecvBuffer(n int) error { +func (u *StdConn) SetRecvBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } -func (u *Conn) SetSendBuffer(n int) error { +func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } -func (u *Conn) GetRecvBuffer() (int, error) { +func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } -func (u *Conn) GetSendBuffer() (int, error) { +func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *Conn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (*Addr, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { return nil, err @@ -119,7 +140,7 @@ func (u *Conn) LocalAddr() (*Addr, error) { return addr, nil } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -137,20 +158,24 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall for { n, err := read(msgs) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + 
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } //metric.Update(int64(n)) for i := 0; i < n; i++ { - udpAddr.IP = names[i][8:24] + if u.isV4 { + udpAddr.IP = names[i][4:8] + } else { + udpAddr.IP = names[i][8:24] + } udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } } -func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) { +func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMSG, @@ -171,7 +196,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) { } } -func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) { +func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMMSG, @@ -191,14 +216,19 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *Conn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, addr *Addr) error { + if u.isV4 { + return u.writeTo4(b, addr) + } + return u.writeTo6(b, addr) +} +func (u *StdConn) writeTo6(b []byte, addr *Addr) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&rsa.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(rsa.Addr[:], addr.IP) + // Little Endian -> Network Endian + rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) + copy(rsa.Addr[:], addr.IP.To16()) for { _, _, err := unix.Syscall6( @@ -221,7 +251,40 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error { } } -func (u *Conn) ReloadConfig(c *config.C) { +func (u *StdConn) writeTo4(b []byte, addr *Addr) error { + addrV4, isAddrV4 := maybeIPV4(addr.IP) + if !isAddrV4 { + return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + + var rsa unix.RawSockaddrInet4 + rsa.Family = unix.AF_INET + // Little Endian -> Network Endian + rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) + copy(rsa.Addr[:], addrV4) + + for { + _, _, err := unix.Syscall6( + unix.SYS_SENDTO, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unix.SizeofSockaddrInet4), + ) + + if err != 0 { + return &net.OpError{Op: "sendto", Err: err} + } + + //TODO: handle incomplete writes + + return nil + } +} + +func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { err := u.SetRecvBuffer(b) @@ -253,7 +316,7 @@ func (u *Conn) ReloadConfig(c *config.C) { } } -func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error { +func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { var vallen uint32 = 4 * _SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { @@ -262,11 +325,16 @@ func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error { return nil } -func NewUDPStatsEmitter(udpConns []*Conn) func() { +func (u *StdConn) Close() error { + //TODO: this will not interrupt the read loop + return syscall.Close(u.sysFd) +} + +func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge var meminfo _SK_MEMINFO - if err := udpConns[0].getMemInfo(&meminfo); err == nil { + if err := 
udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ @@ -285,7 +353,7 @@ func NewUDPStatsEmitter(udpConns []*Conn) func() { return func() { for i, gauges := range udpGauges { - if err := udpConns[i].getMemInfo(&meminfo); err == nil { + if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { for j := 0; j < _SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 06cd38224..523968c23 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,7 +30,7 @@ type rawMessage struct { Len uint32 } -func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index c442405b6..a54f1dfd9 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,7 +33,7 @@ type rawMessage struct { Pad0 [4]byte } -func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) diff --git a/udp/udp_freebsd.go b/udp/udp_netbsd.go similarity index 77% rename from udp/udp_freebsd.go rename to udp/udp_netbsd.go index 10ff94b40..3c14face3 100644 --- a/udp/udp_freebsd.go +++ b/udp/udp_netbsd.go @@ -10,9 +10,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -36,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go new file mode 100644 index 000000000..31c1a554c --- /dev/null +++ b/udp/udp_rio_windows.go @@ -0,0 +1,403 @@ +//go:build !e2e_testing +// +build !e2e_testing + +// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go + +package udp + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/conn/winrio" +) + +// Assert we meet the standard conn interface +var _ Conn = &RIOConn{} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr windows.RawSockaddrInet6 + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +type RIOConn struct { + isOpen atomic.Bool + l *logrus.Logger + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + results [packetsPerRing]winrio.Result +} + +func NewRIOListener(l *logrus.Logger, ip 
net.IP, port int) (*RIOConn, error) { + if !winrio.Initialize() { + return nil, errors.New("could not initialize winrio") + } + + u := &RIOConn{l: l} + + addr := [16]byte{} + copy(addr[:], ip.To16()) + err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + if err != nil { + return nil, fmt.Errorf("bind: %w", err) + } + + for i := 0; i < packetsPerRing; i++ { + err = u.insertReceiveRequest() + if err != nil { + return nil, fmt.Errorf("init rx ring: %w", err) + } + } + + u.isOpen.Store(true) + return u, nil +} + +func (u *RIOConn) bind(sa windows.Sockaddr) error { + var err error + u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return err + } + + // Enable v4 for this socket + syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + + err = u.rx.Open() + if err != nil { + return err + } + + err = u.tx.Open() + if err != nil { + return err + } + + u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0) + if err != nil { + return err + } + + err = windows.Bind(u.sock, sa) + if err != nil { + return err + } + + return nil +} + +func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + buffer := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + udpAddr := &Addr{IP: make([]byte, 16)} + nb := make([]byte, 12, 12) + + for { + // Just read one packet at a time + n, rua, err := u.receive(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + udpAddr.IP = rua.Addr[:] + p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) + p[0] = byte(rua.Port >> 8) + p[1] = byte(rua.Port) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + } +} + +func (u *RIOConn) insertReceiveRequest() error { + packet := u.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + u.rx.mu.Lock() + defer u.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result + +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + procyield(1) + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + } + + if count == 0 { + err = winrio.Notify(u.rx.cq) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + if count == 0 { + return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress + + } + } + + 
u.rx.Return(1) + err = u.insertReceiveRequest() + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. + if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + goto retry + } + + if results[0].Status != 0 { + return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status) + } + + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, ep, nil +} + +func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { + if !u.isOpen.Load() { + return net.ErrClosed + } + + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + + u.tx.mu.Lock() + defer u.tx.mu.Unlock() + + count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 && u.tx.isFull { + err := winrio.Notify(u.tx.cq) + if err != nil { + return err + } + + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + + if !u.isOpen.Load() { + return net.ErrClosed + } + + count = winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + + if count > 0 { + u.tx.Return(count) + } + + packet := u.tx.Push() + packet.addr.Family = windows.AF_INET6 + p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) + p[0] = byte(addr.Port >> 8) + p[1] = byte(addr.Port) + copy(packet.addr.Addr[:], addr.IP.To16()) + copy(packet.data[:], buf) + + dataBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets), + Length: uint32(len(buf)), + } + + addressBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (u *RIOConn) LocalAddr() (*Addr, error) { + sa, err := windows.Getsockname(u.sock) + if err != nil { + return nil, err + } + + v6 := sa.(*windows.SockaddrInet6) + return &Addr{ + IP: v6.Addr[:], + Port: uint16(v6.Port), + }, nil +} + +func (u *RIOConn) Rebind() error { + return nil +} + +func (u *RIOConn) ReloadConfig(*config.C) {} + +func (u *RIOConn) Close() error { + if !u.isOpen.CompareAndSwap(true, false) { + return nil + } + + windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) + + u.rx.CloseAndZero() + u.tx.CloseAndZero() + if u.sock != 0 { + windows.CloseHandle(u.sock) + } + return nil +} + +func (ring *ringBuffer) Push() *ringPacket { + for ring.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + ring.tail += 1 + if ring.tail%packetsPerRing == ring.head%packetsPerRing { + ring.isFull = true + } + return ret +} + +func (ring *ringBuffer) Return(count uint32) { + if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull { + return + } + ring.head += count + ring.isFull = false +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + + 
if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + + return nil +} diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55213b831..55985f47f 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -5,7 +5,9 @@ package udp import ( "fmt" + "io" "net" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -36,18 +38,19 @@ func (u *Packet) Copy() *Packet { return n } -type Conn struct { +type TesterConn struct { Addr *Addr RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - l *logrus.Logger + closed atomic.Bool + l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) { - return &Conn{ - Addr: &Addr{net.ParseIP(ip), uint16(port)}, +func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { + return &TesterConn{ + Addr: &Addr{ip, uint16(port)}, RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -57,16 +60,20 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, e // Send will place a UdpPacket onto the receive queue for nebula to consume // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send -func (u *Conn) Send(packet *Packet) { +func (u *TesterConn) Send(packet *Packet) { + if u.closed.Load() { + return + } + h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) } - if u.l.Level >= logrus.InfoLevel { + if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("dataLen", len(packet.Data)). 
- Info("UDP receiving injected packet") + Debug("UDP receiving injected packet") } u.RxPackets <- packet } @@ -74,7 +81,7 @@ func (u *Conn) Send(packet *Packet) { // Get will pull a UdpPacket from the transmit queue // nebula meant to send this message on the network, it will be encrypted // packets were ingested from the tun side (in most cases), you can send them with Tun.Send -func (u *Conn) Get(block bool) *Packet { +func (u *TesterConn) Get(block bool) *Packet { if block { return <-u.TxPackets } @@ -91,7 +98,11 @@ func (u *Conn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *Conn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { + if u.closed.Load() { + return io.ErrClosedPipe + } + p := &Packet{ Data: make([]byte, len(b), len(b)), FromIp: make([]byte, 16), @@ -108,7 +119,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error { return nil } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -122,21 +133,29 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } ua.Port = p.FromPort copy(ua.IP, p.FromIp.To16()) - r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } -func (u *Conn) ReloadConfig(*config.C) {} +func (u *TesterConn) ReloadConfig(*config.C) {} -func NewUDPStatsEmitter(_ []*Conn) func() { +func NewUDPStatsEmitter(_ []Conn) func() { // No UDP stats for non-linux return func() {} } -func (u *Conn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (*Addr, error) { return u.Addr, nil } -func (u *Conn) Rebind() error { +func (u *TesterConn) Rebind() error { + return nil +} + +func (u *TesterConn) Close() error { + if u.closed.CompareAndSwap(false, true) { + close(u.RxPackets) + close(u.TxPackets) + } return nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1f2ce6475..ebcace670 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -3,14 +3,31 @@ package udp -// Windows support is primarily implemented in udp_generic, besides NewListenConfig - import ( "fmt" "net" "syscall" + + "github.com/sirupsen/logrus" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + if multi { + //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level + // The udp stack would need to be reworked to hide away the implementation differences between + // Windows and Linux + return nil, fmt.Errorf("multiple udp listeners not supported on windows") + } + + rc, err := NewRIOListener(l, ip, port) + if err == nil { + return rc, nil + } + + l.WithError(err).Error("Falling back to standard udp sockets") + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -24,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/util/error.go 
b/util/error.go index 7f9bc4792..d7710f9ab 100644 --- a/util/error.go +++ b/util/error.go @@ -2,6 +2,7 @@ package util import ( "errors" + "fmt" "github.com/sirupsen/logrus" ) @@ -12,18 +13,38 @@ type ContextualError struct { Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { - return ContextualError{Context: msg, Fields: fields, RealError: realError} +func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { + return &ContextualError{Context: msg, Fields: fields, RealError: realError} } -func (ce ContextualError) Error() string { +// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one +func ContextualizeIfNeeded(msg string, err error) error { + switch err.(type) { + case *ContextualError: + return err + default: + return NewContextualError(msg, nil, err) + } +} + +// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError +func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { + switch v := err.(type) { + case *ContextualError: + v.Log(l) + default: + l.WithError(err).Error(msg) + } +} + +func (ce *ContextualError) Error() string { if ce.RealError == nil { return ce.Context } - return ce.RealError.Error() + return fmt.Errorf("%s (%v): %w", ce.Context, ce.Fields, ce.RealError).Error() } -func (ce ContextualError) Unwrap() error { +func (ce *ContextualError) Unwrap() error { if ce.RealError == nil { return errors.New(ce.Context) } diff --git a/util/error_test.go b/util/error_test.go index 747d04e0c..5041f82ce 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -2,6 +2,7 @@ package util import ( "errors" + "fmt" "testing" "github.com/sirupsen/logrus" @@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) { e.Log(l) assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) } + +func TestLogWithContextIfNeeded(t *testing.T) { + l := logrus.New() + l.Formatter = &logrus.TextFormatter{ + DisableTimestamp: true, + DisableColors: true, + } + + tl := NewTestLogWriter() + l.Out = tl + + // Test ignoring fallback context + tl.Reset() + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + LogWithContextIfNeeded("This should get thrown away", e, l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + + // Test using fallback context + tl.Reset() + err := fmt.Errorf("this is a normal error") + LogWithContextIfNeeded("Fallback context woo", err, l) + assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) +} + +func TestContextualizeIfNeeded(t *testing.T) { + // Test ignoring fallback context + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e)) + + // Test using fallback context + err := fmt.Errorf("this is a normal error") + cErr := ContextualizeIfNeeded("Fallback context woo", err) + + switch v := cErr.(type) { + case *ContextualError: + assert.Equal(t, err, v.RealError) + default: + t.Error("Error was not wrapped") + t.Fail() + } +}
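
The renamed listener types above (StdConn on Linux, GenericConn for the generic/BSD path, RIOConn on Windows, TesterConn for e2e tests) all satisfy a shared udp.Conn interface whose declaration sits outside the hunks shown here (presumably in a file like udp/conn.go). The sketch below is only an approximation inferred from the methods the four implementations provide and from the `var _ Conn = &GenericConn{}` / `var _ Conn = &RIOConn{}` assertions; the real declaration may differ.

package udp

import (
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/firewall"
)

// Approximate shape of the shared Conn interface. Inferred from the methods
// implemented by StdConn, GenericConn, RIOConn and TesterConn in this diff;
// not copied from the actual source file, which may list the methods differently.
type Conn interface {
	Rebind() error
	LocalAddr() (*Addr, error)
	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
	WriteTo(b []byte, addr *Addr) error
	ReloadConfig(c *config.C)
	Close() error
}

With the listeners behind an interface, NewUDPStatsEmitter accepts []Conn on every platform, and the Linux implementation type-asserts each element to *StdConn before reading SO_MEMINFO, as seen in the udp_linux.go hunk.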
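
writeTo4 and writeTo6 in udp_linux.go replace the earlier two-byte unsafe pointer write with an explicit swap of the 16-bit port into network byte order before it is stored in the raw sockaddr. A minimal standalone illustration of that arithmetic (not part of the change; the port value 4242 is an arbitrary example) shows the shift/mask expression produces the same byte layout as encoding/binary's big-endian encoding on a little-endian host:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	port := uint16(4242) // 0x1092, arbitrary example value

	// The expression used in writeTo4/writeTo6: swap the two bytes so that
	// storing the result as a native little-endian uint16 lays the bytes out
	// in network (big-endian) order inside the raw sockaddr.
	swapped := (port >> 8) | ((port & 0xff) << 8) // 0x9210

	want := make([]byte, 2)
	binary.BigEndian.PutUint16(want, port) // network-order bytes: 10 92

	got := make([]byte, 2)
	binary.LittleEndian.PutUint16(got, swapped) // same bytes: 10 92

	fmt.Printf("big-endian encoding: %x, swapped-then-stored: %x\n", want, got)
}

Note that this relies on the host being little-endian, which is also what the inline comment ("Little Endian -> Network Endian") assumes.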