From 13de0bb11a1b1b7a3f6a18f5c203a71680247cf1 Mon Sep 17 00:00:00 2001 From: Sam Lock Date: Thu, 7 Mar 2024 16:45:37 +0000 Subject: [PATCH] feat: Add log ingest client (#153) Signed-off-by: Sam Lock --- Makefile | 1 + {bundle => base}/auth.go | 2 +- base/client.go | 89 ++++++ base/client_conf.go | 100 +++++++ {bundle => base}/interceptors.go | 2 +- base/validate.go | 40 +++ bundle/client.go | 81 +----- bundle/client_conf.go | 88 +----- bundle/client_test.go | 35 +-- logcap/client.go | 58 ++++ logcap/client_conf.go | 10 + logcap/client_test.go | 272 ++++++++++++++++++ .../logsv1connect/CerbosLogsServiceHandler.go | 102 +++++++ 13 files changed, 712 insertions(+), 168 deletions(-) rename {bundle => base}/auth.go (99%) create mode 100644 base/client.go create mode 100644 base/client_conf.go rename {bundle => base}/interceptors.go (99%) create mode 100644 base/validate.go create mode 100644 logcap/client.go create mode 100644 logcap/client_conf.go create mode 100644 logcap/client_test.go create mode 100644 test/mocks/genpb/cerbos/cloud/logs/v1/logsv1connect/CerbosLogsServiceHandler.go diff --git a/Makefile b/Makefile index 42caec0..60a7abd 100644 --- a/Makefile +++ b/Makefile @@ -37,6 +37,7 @@ generate-proto-code: proto-gen-deps generate-mocks: $(MOCKERY) @ $(MOCKERY) $(MOCK_QUIET) --srcpkg=./genpb/cerbos/cloud/apikey/v1/apikeyv1connect --name=ApiKeyServiceHandler @ $(MOCKERY) $(MOCK_QUIET) --srcpkg=./genpb/cerbos/cloud/bundle/v1/bundlev1connect --name=CerbosBundleServiceHandler + @ $(MOCKERY) $(MOCK_QUIET) --srcpkg=./genpb/cerbos/cloud/logs/v1/logsv1connect --name=CerbosLogsServiceHandler .PHONY: compile compile: diff --git a/bundle/auth.go b/base/auth.go similarity index 99% rename from bundle/auth.go rename to base/auth.go index 6ed73d9..e62baed 100644 --- a/bundle/auth.go +++ b/base/auth.go @@ -1,6 +1,6 @@ // Copyright 2021-2024 Zenauth Ltd. // SPDX-License-Identifier: Apache-2.0 -package bundle +package base import ( "context" diff --git a/base/client.go b/base/client.go new file mode 100644 index 0000000..5787a70 --- /dev/null +++ b/base/client.go @@ -0,0 +1,89 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 +package base + +import ( + "encoding/json" + "fmt" + "net/http" + + "connectrpc.com/connect" + "connectrpc.com/otelconnect" + "github.com/go-logr/logr" + "github.com/hashicorp/go-retryablehttp" + "golang.org/x/net/http2" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +type Client struct { + HTTPClient *http.Client + conf ClientConf +} + +func NewClient(conf ClientConf) (c Client, opts []connect.ClientOption, _ error) { + otelConnect, err := otelconnect.NewInterceptor() + if err != nil { + return c, opts, fmt.Errorf("failed to create otel interceptor: %w", err) + } + + opts = []connect.ClientOption{ + connect.WithCompressMinBytes(1024), + connect.WithInterceptors( + otelConnect, + newUserAgentInterceptor(), + ), + } + + retryableHTTPClient := mkRetryableHTTPClient(conf) + authClient := newAuthClient(conf, retryableHTTPClient, opts...) + + opts = append(opts, connect.WithInterceptors(newAuthInterceptor(authClient))) + + return Client{ + conf: conf, + HTTPClient: retryableHTTPClient, + }, opts, nil +} + +func MkHTTPClient(conf ClientConf) *http.Client { + return &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: conf.TLS.Clone(), + }, + } +} + +func mkRetryableHTTPClient(conf ClientConf) *http.Client { + httpClient := retryablehttp.NewClient() + httpClient.HTTPClient = MkHTTPClient(conf) + httpClient.RetryMax = conf.RetryMaxAttempts + httpClient.RetryWaitMin = conf.RetryWaitMin + httpClient.RetryWaitMax = conf.RetryWaitMax + httpClient.Logger = logWrapper{Logger: conf.Logger.WithName("transport")} + + return httpClient.StandardClient() +} + +func LogResponsePayload(log logr.Logger, payload proto.Message) { + if lg := log.V(3); lg.Enabled() { + lg.Info("RPC response", "payload", ProtoWrapper{p: payload}) + } +} + +type ProtoWrapper struct { + p proto.Message +} + +func NewProtoWrapper(msg proto.Message) ProtoWrapper { + return ProtoWrapper{p: msg} +} + +func (pw ProtoWrapper) MarshalLog() any { + bytes, err := protojson.Marshal(pw.p) + if err != nil { + return fmt.Sprintf("error marshaling response: %v", err) + } + + return json.RawMessage(bytes) +} diff --git a/base/client_conf.go b/base/client_conf.go new file mode 100644 index 0000000..96d2dd4 --- /dev/null +++ b/base/client_conf.go @@ -0,0 +1,100 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package base + +import ( + "crypto/tls" + "errors" + "fmt" + "time" + + "github.com/go-logr/logr" + "go.uber.org/multierr" + + "github.com/cerbos/cloud-api/credentials" + pdpv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/pdp/v1" +) + +var ( + errEmptyAPIEndpoint = errors.New("api endpoint must be defined") + errEmptyBootstrapEndpoint = errors.New("bootstrap endpoint must be defined") + errHeartbeatIntervalTooShort = errors.New("heartbeat interval is too short") + errMissingCredentials = errors.New("missing credentials") + errMissingIdentifier = errors.New("missing PDP identifier") +) + +const ( + defaultHeartbeatInterval = 2 * time.Minute + defaultRetryWaitMin = 1 * time.Second //nolint:revive + defaultRetryWaitMax = 5 * time.Minute + defaultRetryMaxAttempts = 10 + minHeartbeatInterval = 30 * time.Second +) + +type ClientConf struct { + PDPIdentifier *pdpv1.Identifier + TLS *tls.Config + Logger logr.Logger + Credentials *credentials.Credentials + APIEndpoint string + BootstrapEndpoint string + RetryWaitMin time.Duration + RetryWaitMax time.Duration + RetryMaxAttempts int + HeartbeatInterval time.Duration +} + +func (cc ClientConf) Validate() (outErr error) { + if cc.Credentials == nil { + outErr = multierr.Append(outErr, errMissingCredentials) + } + + if cc.APIEndpoint == "" { + outErr = multierr.Append(outErr, errEmptyAPIEndpoint) + } + + if cc.BootstrapEndpoint == "" { + outErr = multierr.Append(outErr, errEmptyBootstrapEndpoint) + } + + if cc.PDPIdentifier == nil { + outErr = multierr.Append(outErr, errMissingIdentifier) + } else if err := Validate(cc.PDPIdentifier); err != nil { + outErr = multierr.Append(outErr, fmt.Errorf("invalid PDP identifier: %w", err)) + } + + if cc.HeartbeatInterval > 0 && cc.HeartbeatInterval < minHeartbeatInterval { + outErr = multierr.Append(outErr, errHeartbeatIntervalTooShort) + } + + return outErr +} + +func (cc *ClientConf) SetDefaults() { + if cc.RetryMaxAttempts == 0 { + cc.RetryMaxAttempts = defaultRetryMaxAttempts + } + + if cc.RetryWaitMin == 0 { + cc.RetryWaitMin = defaultRetryWaitMin + } + + if cc.RetryWaitMax == 0 { + cc.RetryWaitMax = defaultRetryWaitMax + } + + if cc.HeartbeatInterval == 0 { + cc.HeartbeatInterval = defaultHeartbeatInterval + } +} + +type logWrapper struct { + logr.Logger +} + +func (lw logWrapper) Printf(msg string, args ...any) { + if log := lw.V(1); log.Enabled() { + log.Info(fmt.Sprintf(msg, args...)) + } +} diff --git a/bundle/interceptors.go b/base/interceptors.go similarity index 99% rename from bundle/interceptors.go rename to base/interceptors.go index d4e26c0..718de13 100644 --- a/bundle/interceptors.go +++ b/base/interceptors.go @@ -1,7 +1,7 @@ // Copyright 2021-2024 Zenauth Ltd. // SPDX-License-Identifier: Apache-2.0 -package bundle +package base import ( "context" diff --git a/base/validate.go b/base/validate.go new file mode 100644 index 0000000..44fd24f --- /dev/null +++ b/base/validate.go @@ -0,0 +1,40 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package base + +import ( + "fmt" + "sync" + + "github.com/bufbuild/protovalidate-go" + "google.golang.org/protobuf/proto" + + bootstrapv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bootstrap/v1" +) + +var ( + validateFn func(proto.Message) error + validatorOnce sync.Once +) + +func Validate[T proto.Message](obj T) error { + validatorOnce.Do(func() { + validator, err := protovalidate.New( + protovalidate.WithMessages( + &bootstrapv1.PDPConfig{}, + ), + ) + if err != nil { + validateFn = func(_ proto.Message) error { + return fmt.Errorf("failed to initialize validator: %w", err) + } + } else { + validateFn = func(m proto.Message) error { + return validator.Validate(m) + } + } + }) + + return validateFn(obj) +} diff --git a/bundle/client.go b/bundle/client.go index dc2ed36..bc4e417 100644 --- a/bundle/client.go +++ b/bundle/client.go @@ -6,7 +6,6 @@ package bundle import ( "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -20,18 +19,15 @@ import ( "unicode" "connectrpc.com/connect" - "connectrpc.com/otelconnect" "github.com/go-logr/logr" - "github.com/hashicorp/go-retryablehttp" "github.com/minio/sha256-simd" "github.com/rogpeppe/go-internal/cache" "github.com/sourcegraph/conc/pool" "go.uber.org/multierr" - "golang.org/x/net/http2" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/cerbos/cloud-api/base" bootstrapv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bootstrap/v1" bundlev1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bundle/v1" "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bundle/v1/bundlev1connect" @@ -105,11 +101,10 @@ type ClientEvent struct { } type Client struct { - authClient *authClient rpcClient bundlev1connect.CerbosBundleServiceClient - httpClient *http.Client bundleCache *cache.Cache conf ClientConf + base.Client } func NewClient(conf ClientConf) (*Client, error) { @@ -117,36 +112,24 @@ func NewClient(conf ClientConf) (*Client, error) { return nil, err } - bcache, err := mkBundleCache(conf.CacheDir) + baseClient, options, err := base.NewClient(conf.ClientConf) if err != nil { return nil, err } - otelConnect, err := otelconnect.NewInterceptor() + bcache, err := mkBundleCache(conf.CacheDir) if err != nil { - return nil, fmt.Errorf("failed to create otel interceptor: %w", err) - } - - httpClient := mkHTTPClient(conf) // Bidi streams don't work with retryable HTTP client. - retryableHTTPClient := mkRetryableHTTPClient(conf) - options := []connect.ClientOption{ - connect.WithCompressMinBytes(1024), - connect.WithInterceptors( - otelConnect, - newUserAgentInterceptor(), - ), + return nil, err } - authClient := newAuthClient(conf, retryableHTTPClient, options...) - options = append(options, connect.WithInterceptors(newAuthInterceptor(authClient))) + httpClient := base.MkHTTPClient(conf.ClientConf) // Bidi streams don't work with retryable HTTP client. rpcClient := bundlev1connect.NewCerbosBundleServiceClient(httpClient, conf.APIEndpoint, options...) return &Client{ + Client: baseClient, bundleCache: bcache, conf: conf, - authClient: authClient, rpcClient: rpcClient, - httpClient: retryableHTTPClient, }, nil } @@ -173,25 +156,6 @@ func mkBundleCache(path string) (*cache.Cache, error) { return c, nil } -func mkHTTPClient(conf ClientConf) *http.Client { - return &http.Client{ - Transport: &http2.Transport{ - TLSClientConfig: conf.TLS.Clone(), - }, - } -} - -func mkRetryableHTTPClient(conf ClientConf) *http.Client { - httpClient := retryablehttp.NewClient() - httpClient.HTTPClient = mkHTTPClient(conf) - httpClient.RetryMax = conf.RetryMaxAttempts - httpClient.RetryWaitMin = conf.RetryWaitMin - httpClient.RetryWaitMax = conf.RetryWaitMax - httpClient.Logger = logWrapper{Logger: conf.Logger.WithName("transport")} - - return httpClient.StandardClient() -} - func (c *Client) BootstrapBundle(ctx context.Context, bundleLabel string) (string, error) { log := c.conf.Logger.WithValues("bundle", bundleLabel) log.V(1).Info("Getting bootstrap configuration") @@ -213,7 +177,7 @@ func (c *Client) BootstrapBundle(ctx context.Context, bundleLabel string) (strin log.Info("Bootstrap configuration downloaded", "created_at", meta.CreatedAt.AsTime(), "commit_hash", meta.CommitHash) } - logResponsePayload(log, bootstrapConf) + base.LogResponsePayload(log, bootstrapConf) return c.getBundleFile(logr.NewContext(ctx, log), bootstrapConf.BundleInfo) } @@ -226,7 +190,7 @@ func (c *Client) downloadBootstrapConf(ctx context.Context, url string) (*bootst } log.V(1).Info("Sending download request") - resp, err := c.httpClient.Do(req) + resp, err := c.HTTPClient.Do(req) if err != nil { log.V(1).Error(err, "Failed to send download request") return nil, fmt.Errorf("failed to send download request: %w", err) @@ -296,7 +260,7 @@ func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, err return "", err } - logResponsePayload(log, resp.Msg) + base.LogResponsePayload(log, resp.Msg) return c.getBundleFile(logr.NewContext(ctx, log), resp.Msg.BundleInfo) } @@ -396,7 +360,7 @@ func (c *Client) watchStreamRecv(stream *connect.BidiStreamForClient[bundlev1.Wa } processMsg := func(msg *bundlev1.WatchBundleResponse) error { - logResponsePayload(log, msg) + base.LogResponsePayload(log, msg) switch m := msg.Msg.(type) { case *bundlev1.WatchBundleResponse_BundleUpdate: @@ -435,7 +399,7 @@ func (c *Client) watchStreamRecv(stream *connect.BidiStreamForClient[bundlev1.Wa return ReconnectError{Backoff: backoff} default: - log.V(2).Info("Ignoring unknown message", "msg", protoWrapper{p: msg}) + log.V(2).Info("Ignoring unknown message", "msg", base.NewProtoWrapper(msg)) } return nil @@ -567,12 +531,6 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev1.Wa } } -func logResponsePayload(log logr.Logger, payload proto.Message) { - if lg := log.V(3); lg.Enabled() { - lg.Info("RPC response", "payload", protoWrapper{p: payload}) - } -} - func (c *Client) getBundleFile(ctx context.Context, binfo *bundlev1.BundleInfo) (outPath string, outErr error) { log := logr.FromContextOrDiscard(ctx) @@ -670,7 +628,7 @@ func (c *Client) doDownloadSegment(ctx context.Context, cacheKey cache.ActionID, } log.V(1).Info("Sending download request") - resp, err := c.httpClient.Do(req) + resp, err := c.HTTPClient.Do(req) if err != nil { log.V(1).Error(err, "Failed to send download request") if r.size() > 1 && attempt < maxDownloadAttempts { @@ -841,16 +799,3 @@ func (r *ring) next() string { func (r *ring) size() int { return len(r.elements) } - -type protoWrapper struct { - p proto.Message -} - -func (pw protoWrapper) MarshalLog() any { - bytes, err := protojson.Marshal(pw.p) - if err != nil { - return fmt.Sprintf("error marshaling response: %v", err) - } - - return json.RawMessage(bytes) -} diff --git a/bundle/client_conf.go b/bundle/client_conf.go index d3abddf..eb3bd15 100644 --- a/bundle/client_conf.go +++ b/bundle/client_conf.go @@ -4,67 +4,23 @@ package bundle import ( - "crypto/tls" - "errors" "fmt" "os" - "time" - "github.com/go-logr/logr" "go.uber.org/multierr" - "github.com/cerbos/cloud-api/credentials" - pdpv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/pdp/v1" -) - -var ( - errEmptyAPIEndpoint = errors.New("api endpoint must be defined") - errEmptyBootstrapEndpoint = errors.New("bootstrap endpoint must be defined") - errHeartbeatIntervalTooShort = errors.New("heartbeat interval is too short") - errMissingCredentials = errors.New("missing credentials") - errMissingIdentifier = errors.New("missing PDP identifier") -) - -const ( - defaultHeartbeatInterval = 2 * time.Minute - defaultRetryWaitMin = 1 * time.Second //nolint:revive - defaultRetryWaitMax = 5 * time.Minute - defaultRetryMaxAttempts = 10 - minHeartbeatInterval = 30 * time.Second + "github.com/cerbos/cloud-api/base" ) type ClientConf struct { - PDPIdentifier *pdpv1.Identifier - TLS *tls.Config - Logger logr.Logger - Credentials *credentials.Credentials - APIEndpoint string - BootstrapEndpoint string - CacheDir string - TempDir string - RetryWaitMin time.Duration - RetryWaitMax time.Duration - RetryMaxAttempts int - HeartbeatInterval time.Duration + CacheDir string + TempDir string + base.ClientConf } func (cc ClientConf) Validate() (outErr error) { - if cc.Credentials == nil { - outErr = multierr.Append(outErr, errMissingCredentials) - } - - if cc.APIEndpoint == "" { - outErr = multierr.Append(outErr, errEmptyAPIEndpoint) - } - - if cc.BootstrapEndpoint == "" { - outErr = multierr.Append(outErr, errEmptyBootstrapEndpoint) - } - - if cc.PDPIdentifier == nil { - outErr = multierr.Append(outErr, errMissingIdentifier) - } else if err := Validate(cc.PDPIdentifier); err != nil { - outErr = multierr.Append(outErr, fmt.Errorf("invalid PDP identifier: %w", err)) + if err := cc.ClientConf.Validate(); err != nil { + outErr = multierr.Append(outErr, err) } if cc.CacheDir != "" { @@ -79,31 +35,9 @@ func (cc ClientConf) Validate() (outErr error) { } } - if cc.HeartbeatInterval > 0 && cc.HeartbeatInterval < minHeartbeatInterval { - outErr = multierr.Append(outErr, errHeartbeatIntervalTooShort) - } - return outErr } -func (cc *ClientConf) SetDefaults() { - if cc.RetryMaxAttempts == 0 { - cc.RetryMaxAttempts = defaultRetryMaxAttempts - } - - if cc.RetryWaitMin == 0 { - cc.RetryWaitMin = defaultRetryWaitMin - } - - if cc.RetryWaitMax == 0 { - cc.RetryWaitMax = defaultRetryWaitMax - } - - if cc.HeartbeatInterval == 0 { - cc.HeartbeatInterval = defaultHeartbeatInterval - } -} - func validateDir(path string) error { stat, err := os.Stat(path) if err != nil { @@ -116,13 +50,3 @@ func validateDir(path string) error { return nil } - -type logWrapper struct { - logr.Logger -} - -func (lw logWrapper) Printf(msg string, args ...any) { - if log := lw.V(1); log.Enabled() { - log.Info(fmt.Sprintf(msg, args...)) - } -} diff --git a/bundle/client_test.go b/bundle/client_test.go index cf1a57b..90a1884 100644 --- a/bundle/client_test.go +++ b/bundle/client_test.go @@ -45,6 +45,7 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/cerbos/cloud-api/base" "github.com/cerbos/cloud-api/bundle" "github.com/cerbos/cloud-api/credentials" apikeyv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1" @@ -61,7 +62,7 @@ const testPrivateKey = "CERBOS-1MKYX97DHPT3B-L05ALANNYUXY7HEMFXUNQRLS47D8G8D9ZYU var pdpIdentifer = &pdpv1.Identifier{ Instance: "instance", - Version: "0.19.0", + Version: "0.34.0", } func TestBootstrapBundle(t *testing.T) { @@ -551,7 +552,7 @@ func TestGetBundle(t *testing.T) { _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) - require.ErrorIs(t, err, bundle.ErrAuthenticationFailed) + require.ErrorIs(t, err, base.ErrAuthenticationFailed) }) } @@ -802,7 +803,7 @@ func TestWatchBundle(t *testing.T) { _, err := client.WatchBundle(context.Background(), "label") require.Error(t, err) - require.ErrorIs(t, err, bundle.ErrAuthenticationFailed) + require.ErrorIs(t, err, base.ErrAuthenticationFailed) }) } @@ -1079,17 +1080,19 @@ func mkClient(t *testing.T, url string, cert *x509.Certificate) (*bundle.Client, require.NoError(t, err, "Failed to create credentials") conf := bundle.ClientConf{ - Credentials: creds, - BootstrapEndpoint: url, - APIEndpoint: url, - PDPIdentifier: pdpIdentifer, - RetryWaitMin: 10 * time.Millisecond, - RetryWaitMax: 30 * time.Millisecond, - RetryMaxAttempts: 2, - CacheDir: cacheDir, - TempDir: tempDir, - Logger: testr.NewWithOptions(t, testr.Options{Verbosity: 4}), - TLS: tlsConf, + ClientConf: base.ClientConf{ + Credentials: creds, + BootstrapEndpoint: url, + APIEndpoint: url, + PDPIdentifier: pdpIdentifer, + RetryWaitMin: 10 * time.Millisecond, + RetryWaitMax: 30 * time.Millisecond, + RetryMaxAttempts: 2, + Logger: testr.NewWithOptions(t, testr.Options{Verbosity: 4}), + TLS: tlsConf, + }, + CacheDir: cacheDir, + TempDir: tempDir, } client, err := bundle.NewClient(conf) @@ -1228,7 +1231,7 @@ type authCheck struct{} func (ac authCheck) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { - if req.Header().Get(bundle.AuthTokenHeader) != "access-token" { + if req.Header().Get(base.AuthTokenHeader) != "access-token" { return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("invalid or missing access token")) } return next(ctx, req) @@ -1241,7 +1244,7 @@ func (ac authCheck) WrapStreamingClient(next connect.StreamingClientFunc) connec func (ac authCheck) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error { - if conn.RequestHeader().Get(bundle.AuthTokenHeader) != "access-token" { + if conn.RequestHeader().Get(base.AuthTokenHeader) != "access-token" { return connect.NewError(connect.CodeUnauthenticated, errors.New("invalid or missing access token")) } diff --git a/logcap/client.go b/logcap/client.go new file mode 100644 index 0000000..dbcd975 --- /dev/null +++ b/logcap/client.go @@ -0,0 +1,58 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package logcap + +import ( + "context" + + "connectrpc.com/connect" + + "github.com/cerbos/cloud-api/base" + logsv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/logs/v1" + "github.com/cerbos/cloud-api/genpb/cerbos/cloud/logs/v1/logsv1connect" +) + +type Client struct { + base.Client + rpcClient logsv1connect.CerbosLogsServiceClient + conf ClientConf +} + +func NewClient(conf ClientConf) (*Client, error) { + if err := conf.Validate(); err != nil { + return nil, err + } + + baseClient, options, err := base.NewClient(conf.ClientConf) + if err != nil { + return nil, err + } + + httpClient := base.MkHTTPClient(conf.ClientConf) // Bidi streams don't work with retryable HTTP client. + rpcClient := logsv1connect.NewCerbosLogsServiceClient(httpClient, conf.APIEndpoint, options...) + + return &Client{ + Client: baseClient, + conf: conf, + rpcClient: rpcClient, + }, nil +} + +func (c *Client) Ingest(ctx context.Context, batch *logsv1.IngestBatch) error { + log := c.conf.Logger + log.V(1).Info("Calling Ingest RPC") + + resp, err := c.rpcClient.Ingest(ctx, connect.NewRequest(&logsv1.IngestRequest{ + PdpId: c.conf.PDPIdentifier, + Batch: batch, + })) + if err != nil { + log.Error(err, "Ingest RPC failed") + return err + } + + base.LogResponsePayload(log, resp.Msg) + + return nil +} diff --git a/logcap/client_conf.go b/logcap/client_conf.go new file mode 100644 index 0000000..dcf79c6 --- /dev/null +++ b/logcap/client_conf.go @@ -0,0 +1,10 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package logcap + +import "github.com/cerbos/cloud-api/base" + +type ClientConf struct { + base.ClientConf +} diff --git a/logcap/client_test.go b/logcap/client_test.go new file mode 100644 index 0000000..104deb2 --- /dev/null +++ b/logcap/client_test.go @@ -0,0 +1,272 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +//go:build tests +// +build tests + +package logcap_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "connectrpc.com/connect" + "connectrpc.com/grpcreflect" + auditv1 "github.com/cerbos/cerbos/api/genpb/cerbos/audit/v1" + effectv1 "github.com/cerbos/cerbos/api/genpb/cerbos/effect/v1" + enginev1 "github.com/cerbos/cerbos/api/genpb/cerbos/engine/v1" + "github.com/go-logr/logr/testr" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/cerbos/cloud-api/base" + "github.com/cerbos/cloud-api/credentials" + apikeyv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1" + "github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1/apikeyv1connect" + logsv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/logs/v1" + "github.com/cerbos/cloud-api/genpb/cerbos/cloud/logs/v1/logsv1connect" + pdpv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/pdp/v1" + "github.com/cerbos/cloud-api/logcap" + mockapikeyv1connect "github.com/cerbos/cloud-api/test/mocks/genpb/cerbos/cloud/apikey/v1/apikeyv1connect" + mocklogsv1connect "github.com/cerbos/cloud-api/test/mocks/genpb/cerbos/cloud/logs/v1/logsv1connect" +) + +const testPrivateKey = "CERBOS-1MKYX97DHPT3B-L05ALANNYUXY7HEMFXUNQRLS47D8G8D9ZYUMEDPE4X2382Q2WMSSXY2G2A" + +var pdpIdentifer = &pdpv1.Identifier{ + Instance: "instance", + Version: "0.34.0", +} + +func TestIngest(t *testing.T) { + mockAPIKeySvc := mockapikeyv1connect.NewApiKeyServiceHandler(t) + mockLogsSvc := mocklogsv1connect.NewCerbosLogsServiceHandler(t) + server := startTestServer(t, mockAPIKeySvc, mockLogsSvc) + t.Cleanup(server.Close) + + client := mkClient(t, server.URL, server.Certificate()) + + t.Run("Success", func(t *testing.T) { + mockAPIKeySvc.EXPECT(). + IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())). + Return(connect.NewResponse(&apikeyv1.IssueAccessTokenResponse{ + AccessToken: "access-token", + ExpiresIn: durationpb.New(1 * time.Minute), + }), nil) + + now := time.Now() + + batch := &logsv1.IngestBatch{ + Id: "foo", + Entries: []*logsv1.IngestBatch_Entry{ + { + Kind: logsv1.IngestBatch_ENTRY_KIND_ACCESS_LOG, + Timestamp: ×tamppb.Timestamp{}, + Entry: &logsv1.IngestBatch_Entry_AccessLogEntry{ + AccessLogEntry: &auditv1.AccessLogEntry{ + CallId: "1", + Timestamp: timestamppb.New(now.Add(time.Duration(1) * time.Second)), + Peer: &auditv1.Peer{ + Address: "1.1.1.1", + }, + Metadata: map[string]*auditv1.MetaValues{}, + Method: "/cerbos.svc.v1.CerbosService/Check", + }, + }, + }, + { + Kind: logsv1.IngestBatch_ENTRY_KIND_DECISION_LOG, + Timestamp: ×tamppb.Timestamp{}, + Entry: &logsv1.IngestBatch_Entry_DecisionLogEntry{ + DecisionLogEntry: &auditv1.DecisionLogEntry{ + CallId: "2", + Timestamp: timestamppb.New(now.Add(time.Duration(2) * time.Second)), + Inputs: []*enginev1.CheckInput{ + { + RequestId: "2", + Resource: &enginev1.Resource{ + Kind: "test:kind", + Id: "test", + }, + Principal: &enginev1.Principal{ + Id: "test", + Roles: []string{"a", "b"}, + }, + Actions: []string{"a1", "a2"}, + }, + }, + Outputs: []*enginev1.CheckOutput{ + { + RequestId: "2", + ResourceId: "test", + Actions: map[string]*enginev1.CheckOutput_ActionEffect{ + "a1": {Effect: effectv1.Effect_EFFECT_ALLOW, Policy: "resource.test.v1"}, + "a2": {Effect: effectv1.Effect_EFFECT_ALLOW, Policy: "resource.test.v1"}, + }, + }, + }, + }, + }, + }, + }, + } + + want := &logsv1.IngestRequest{ + PdpId: pdpIdentifer, + Batch: batch, + } + + mockLogsSvc.EXPECT(). + Ingest(mock.Anything, mock.MatchedBy(func(c *connect.Request[logsv1.IngestRequest]) bool { + return cmp.Diff(c.Msg, want, protocmp.Transform()) == "" + })). + Return(connect.NewResponse(&logsv1.IngestResponse{ + Status: &logsv1.IngestResponse_Success{}, + }), nil).Once() + + err := client.Ingest(context.Background(), batch) + require.NoError(t, err) + }) + + t.Run("AuthenticationFailure", func(t *testing.T) { + mockAPIKeySvc := mockapikeyv1connect.NewApiKeyServiceHandler(t) + mockLogsSvc := mocklogsv1connect.NewCerbosLogsServiceHandler(t) + server := startTestServer(t, mockAPIKeySvc, mockLogsSvc) + t.Cleanup(server.Close) + + client := mkClient(t, server.URL, server.Certificate()) + + mockAPIKeySvc.EXPECT(). + IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())). + Return(nil, connect.NewError(connect.CodeUnauthenticated, errors.New("🙅"))) + + err := client.Ingest(context.Background(), &logsv1.IngestBatch{}) + require.Error(t, err) + require.ErrorIs(t, err, base.ErrAuthenticationFailed) + }) +} + +func startTestServer(t *testing.T, mockAPIKeySvc apikeyv1connect.ApiKeyServiceHandler, mockLogsSvc logsv1connect.CerbosLogsServiceHandler) *httptest.Server { + t.Helper() + + compress1KB := connect.WithCompressMinBytes(1024) + apiKeyPath, apiKeySvcHandler := apikeyv1connect.NewApiKeyServiceHandler(mockAPIKeySvc, compress1KB) + logsPath, logsSvcHandler := logsv1connect.NewCerbosLogsServiceHandler(mockLogsSvc, connect.WithInterceptors(authCheck{}), compress1KB) + + logRequests := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("REQUEST: %s", r.URL) + h.ServeHTTP(w, r) + }) + } + mux := http.NewServeMux() + mux.Handle(apiKeyPath, logRequests(apiKeySvcHandler)) + mux.Handle(logsPath, logRequests(logsSvcHandler)) + mux.Handle(grpcreflect.NewHandlerV1( + grpcreflect.NewStaticReflector(logsv1connect.CerbosLogsServiceName), + compress1KB, + )) + mux.Handle(grpcreflect.NewHandlerV1Alpha( + grpcreflect.NewStaticReflector(logsv1connect.CerbosLogsServiceName), + compress1KB, + )) + + s := httptest.NewUnstartedServer(h2c.NewHandler(mux, &http2.Server{})) + s.EnableHTTP2 = true + s.StartTLS() + + return s +} + +func mkClient(t *testing.T, url string, cert *x509.Certificate) *logcap.Client { + t.Helper() + + var tlsConf *tls.Config + if cert != nil { + certPool := x509.NewCertPool() + certPool.AddCert(cert) + tlsConf = &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519, + }, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + NextProtos: []string{"h2"}, + RootCAs: certPool, + } + } + + creds, err := credentials.New("client-id", "client-secret", testPrivateKey) + require.NoError(t, err, "Failed to create credentials") + + conf := logcap.ClientConf{ + ClientConf: base.ClientConf{ + Credentials: creds, + BootstrapEndpoint: url, + APIEndpoint: url, + PDPIdentifier: pdpIdentifer, + RetryWaitMin: 10 * time.Millisecond, + RetryWaitMax: 30 * time.Millisecond, + RetryMaxAttempts: 2, + Logger: testr.NewWithOptions(t, testr.Options{Verbosity: 4}), + TLS: tlsConf, + }, + } + + client, err := logcap.NewClient(conf) + require.NoError(t, err, "Failed to create client") + + return client +} + +type authCheck struct{} + +func (ac authCheck) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if req.Header().Get(base.AuthTokenHeader) != "access-token" { + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("invalid or missing access token")) + } + return next(ctx, req) + }) +} + +func (ac authCheck) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return next +} + +func (ac authCheck) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error { + if conn.RequestHeader().Get(base.AuthTokenHeader) != "access-token" { + return connect.NewError(connect.CodeUnauthenticated, errors.New("invalid or missing access token")) + } + + return next(ctx, conn) + }) +} + +func issueAccessTokenRequest() func(*connect.Request[apikeyv1.IssueAccessTokenRequest]) bool { + return func(req *connect.Request[apikeyv1.IssueAccessTokenRequest]) bool { + return req.Msg.ClientId == "client-id" && req.Msg.ClientSecret == "client-secret" + } +} diff --git a/test/mocks/genpb/cerbos/cloud/logs/v1/logsv1connect/CerbosLogsServiceHandler.go b/test/mocks/genpb/cerbos/cloud/logs/v1/logsv1connect/CerbosLogsServiceHandler.go new file mode 100644 index 0000000..374ca59 --- /dev/null +++ b/test/mocks/genpb/cerbos/cloud/logs/v1/logsv1connect/CerbosLogsServiceHandler.go @@ -0,0 +1,102 @@ +// Copyright 2021-2024 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery v2.42.0. DO NOT EDIT. + +package mocklogsv1connect + +import ( + context "context" + + connect "connectrpc.com/connect" + + logsv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/logs/v1" + + mock "github.com/stretchr/testify/mock" +) + +// CerbosLogsServiceHandler is an autogenerated mock type for the CerbosLogsServiceHandler type +type CerbosLogsServiceHandler struct { + mock.Mock +} + +type CerbosLogsServiceHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *CerbosLogsServiceHandler) EXPECT() *CerbosLogsServiceHandler_Expecter { + return &CerbosLogsServiceHandler_Expecter{mock: &_m.Mock} +} + +// Ingest provides a mock function with given fields: _a0, _a1 +func (_m *CerbosLogsServiceHandler) Ingest(_a0 context.Context, _a1 *connect.Request[logsv1.IngestRequest]) (*connect.Response[logsv1.IngestResponse], error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Ingest") + } + + var r0 *connect.Response[logsv1.IngestResponse] + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *connect.Request[logsv1.IngestRequest]) (*connect.Response[logsv1.IngestResponse], error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *connect.Request[logsv1.IngestRequest]) *connect.Response[logsv1.IngestResponse]); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*connect.Response[logsv1.IngestResponse]) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *connect.Request[logsv1.IngestRequest]) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CerbosLogsServiceHandler_Ingest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ingest' +type CerbosLogsServiceHandler_Ingest_Call struct { + *mock.Call +} + +// Ingest is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *connect.Request[logsv1.IngestRequest] +func (_e *CerbosLogsServiceHandler_Expecter) Ingest(_a0 interface{}, _a1 interface{}) *CerbosLogsServiceHandler_Ingest_Call { + return &CerbosLogsServiceHandler_Ingest_Call{Call: _e.mock.On("Ingest", _a0, _a1)} +} + +func (_c *CerbosLogsServiceHandler_Ingest_Call) Run(run func(_a0 context.Context, _a1 *connect.Request[logsv1.IngestRequest])) *CerbosLogsServiceHandler_Ingest_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*connect.Request[logsv1.IngestRequest])) + }) + return _c +} + +func (_c *CerbosLogsServiceHandler_Ingest_Call) Return(_a0 *connect.Response[logsv1.IngestResponse], _a1 error) *CerbosLogsServiceHandler_Ingest_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CerbosLogsServiceHandler_Ingest_Call) RunAndReturn(run func(context.Context, *connect.Request[logsv1.IngestRequest]) (*connect.Response[logsv1.IngestResponse], error)) *CerbosLogsServiceHandler_Ingest_Call { + _c.Call.Return(run) + return _c +} + +// NewCerbosLogsServiceHandler creates a new instance of CerbosLogsServiceHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCerbosLogsServiceHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *CerbosLogsServiceHandler { + mock := &CerbosLogsServiceHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}