Skip to content

Commit

Permalink
enhancement: Use singleton base client
Browse files Browse the repository at this point in the history
Avoid authenticating multiple times for different APIs.

Signed-off-by: Charith Ellawala <[email protected]>
  • Loading branch information
charithe committed Jul 12, 2024
1 parent eecad8b commit 654b2f7
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 88 deletions.
11 changes: 10 additions & 1 deletion base/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"connectrpc.com/connect"
"github.com/go-logr/logr"

apikeyv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1"
"github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1/apikeyv1connect"
Expand All @@ -24,9 +25,10 @@ const (
var ErrAuthenticationFailed = errors.New("failed to authenticate: invalid credentials")

type authClient struct {
accessToken string
expiresAt time.Time
apiKeyClient apikeyv1connect.ApiKeyServiceClient
logger logr.Logger
accessToken string
clientID string
clientSecret string
mutex sync.RWMutex
Expand All @@ -37,12 +39,14 @@ func newAuthClient(conf ClientConf, httpClient *http.Client, clientOptions ...co
apiKeyClient: apikeyv1connect.NewApiKeyServiceClient(httpClient, conf.APIEndpoint, clientOptions...),
clientID: conf.Credentials.ClientID,
clientSecret: conf.Credentials.ClientSecret,
logger: conf.Logger.WithName("auth"),
}
}

func (a *authClient) SetAuthTokenHeader(ctx context.Context, headers http.Header) error {
accessToken, err := a.authenticate(ctx)
if err != nil {
a.logger.V(1).Error(err, "Failed to authenticate")
return err
}

Expand All @@ -55,6 +59,7 @@ func (a *authClient) authenticate(ctx context.Context) (string, error) {
accessToken, ok := a.currentAccessToken()
a.mutex.RUnlock()
if ok {
a.logger.V(4).Info("Using existing token")
return accessToken, nil
}

Expand All @@ -63,14 +68,17 @@ func (a *authClient) authenticate(ctx context.Context) (string, error) {

accessToken, ok = a.currentAccessToken()
if ok {
a.logger.V(4).Info("Using existing token")
return accessToken, nil
}

a.logger.V(4).Info("Obtaining new access token")
response, err := a.apiKeyClient.IssueAccessToken(ctx, connect.NewRequest(&apikeyv1.IssueAccessTokenRequest{
ClientId: a.clientID,
ClientSecret: a.clientSecret,
}))
if err != nil {
a.logger.V(1).Error(err, "Failed to authenticate")
if connect.CodeOf(err) == connect.CodeUnauthenticated {
return "", ErrAuthenticationFailed
}
Expand All @@ -84,6 +92,7 @@ func (a *authClient) authenticate(ctx context.Context) (string, error) {

a.accessToken = response.Msg.AccessToken
a.expiresAt = time.Now().Add(expiresIn)
a.logger.V(4).Info("Obtained new access token")

return a.accessToken, nil
}
Expand Down
18 changes: 14 additions & 4 deletions base/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (
"golang.org/x/net/http2"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

"github.com/cerbos/cloud-api/credentials"
)

type Client struct {
HTTPClient *http.Client
conf ClientConf
ClientConf
}

func NewClient(conf ClientConf) (c Client, opts []connect.ClientOption, _ error) {
Expand All @@ -41,12 +43,20 @@ func NewClient(conf ClientConf) (c Client, opts []connect.ClientOption, _ error)
opts = append(opts, connect.WithInterceptors(newAuthInterceptor(authClient)))

return Client{
conf: conf,
ClientConf: conf,
HTTPClient: retryableHTTPClient,
}, opts, nil
}

func MkHTTPClient(conf ClientConf) *http.Client {
func (c Client) StdHTTPClient() *http.Client {
return mkHTTPClient(c.ClientConf)
}

func (c Client) HubCredentials() *credentials.Credentials {
return c.Credentials
}

func mkHTTPClient(conf ClientConf) *http.Client {
return &http.Client{
Transport: &http2.Transport{
TLSClientConfig: conf.TLS.Clone(),
Expand All @@ -56,7 +66,7 @@ func MkHTTPClient(conf ClientConf) *http.Client {

func mkRetryableHTTPClient(conf ClientConf) *http.Client {
httpClient := retryablehttp.NewClient()
httpClient.HTTPClient = MkHTTPClient(conf)
httpClient.HTTPClient = mkHTTPClient(conf)
httpClient.RetryMax = conf.RetryMaxAttempts
httpClient.RetryWaitMin = conf.RetryWaitMin
httpClient.RetryWaitMax = conf.RetryWaitMax
Expand Down
35 changes: 15 additions & 20 deletions bundle/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,18 @@ type Client struct {
base.Client
}

func NewClient(conf ClientConf) (*Client, error) {
func NewClient(conf ClientConf, baseClient base.Client, options []connect.ClientOption) (*Client, error) {
if err := conf.Validate(); err != nil {
return nil, err
}

baseClient, options, err := base.NewClient(conf.ClientConf)
if err != nil {
return nil, err
}

bcache, err := mkBundleCache(conf.CacheDir)
if err != nil {
return nil, err
}

httpClient := base.MkHTTPClient(conf.ClientConf) // Bidi streams don't work with retryable HTTP client.
rpcClient := bundlev1connect.NewCerbosBundleServiceClient(httpClient, conf.APIEndpoint, options...)
httpClient := baseClient.StdHTTPClient() // Bidi streams don't work with retryable HTTP client.
rpcClient := bundlev1connect.NewCerbosBundleServiceClient(httpClient, baseClient.APIEndpoint, options...)

return &Client{
Client: baseClient,
Expand Down Expand Up @@ -157,12 +152,12 @@ func mkBundleCache(path string) (*cache.Cache, error) {
}

func (c *Client) BootstrapBundle(ctx context.Context, bundleLabel string) (string, error) {
log := c.conf.Logger.WithValues("bundle", bundleLabel)
log := c.Logger.WithValues("bundle", bundleLabel)
log.V(1).Info("Getting bootstrap configuration")

wsID := c.conf.Credentials.HashString(c.conf.Credentials.WorkspaceID)
labelHash := c.conf.Credentials.HashString(bundleLabel)
bootstrapURL, err := url.JoinPath(c.conf.BootstrapEndpoint, bootstrapPathPrefix, wsID, labelHash)
wsID := c.Credentials.HashString(c.Credentials.WorkspaceID)
labelHash := c.Credentials.HashString(bundleLabel)
bootstrapURL, err := url.JoinPath(c.BootstrapEndpoint, bootstrapPathPrefix, wsID, labelHash)
if err != nil {
return "", fmt.Errorf("failed to construct bootstrap URL: %w", err)
}
Expand Down Expand Up @@ -214,7 +209,7 @@ func (c *Client) downloadBootstrapConf(ctx context.Context, url string) (*bootst
return nil, errDownloadFailed
}

confData, err := c.conf.Credentials.Decrypt(io.LimitReader(resp.Body, maxBootstrapSize))
confData, err := c.Credentials.Decrypt(io.LimitReader(resp.Body, maxBootstrapSize))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -248,11 +243,11 @@ func (c *Client) parseBootstrapConf(input io.Reader) (*bootstrapv1.PDPConfig, er

// GetBundle returns the path to the bundle with the given label.
func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, error) {
log := c.conf.Logger.WithValues("bundle", bundleLabel)
log := c.Logger.WithValues("bundle", bundleLabel)
log.V(1).Info("Calling GetBundle RPC")

resp, err := c.rpcClient.GetBundle(ctx, connect.NewRequest(&bundlev1.GetBundleRequest{
PdpId: c.conf.PDPIdentifier,
PdpId: c.PDPIdentifier,
BundleLabel: bundleLabel,
}))
if err != nil {
Expand All @@ -266,7 +261,7 @@ func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, err
}

func (c *Client) WatchBundle(ctx context.Context, bundleLabel string) (WatchHandle, error) {
log := c.conf.Logger.WithValues("bundle", bundleLabel)
log := c.Logger.WithValues("bundle", bundleLabel)
log.V(1).Info("Calling WatchBundle RPC")

stream := c.rpcClient.WatchBundle(ctx)
Expand Down Expand Up @@ -451,8 +446,8 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev1.Wa
var ticker *time.Ticker
var tickerChan <-chan time.Time

if c.conf.HeartbeatInterval > 0 {
ticker = time.NewTicker(c.conf.HeartbeatInterval)
if c.HeartbeatInterval > 0 {
ticker = time.NewTicker(c.HeartbeatInterval)
tickerChan = ticker.C
} else {
log.V(1).Info("Regular heartbeats disabled")
Expand All @@ -475,7 +470,7 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev1.Wa

log.V(2).Info("Initiating bundle watch")
if err := stream.Send(&bundlev1.WatchBundleRequest{
PdpId: c.conf.PDPIdentifier,
PdpId: c.PDPIdentifier,
Msg: &bundlev1.WatchBundleRequest_WatchLabel_{
WatchLabel: &bundlev1.WatchBundleRequest_WatchLabel{BundleLabel: wh.bundleLabel},
},
Expand All @@ -487,7 +482,7 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev1.Wa
sendHeartbeat := func(activeBundleID string) error {
log.V(3).Info("Sending heartbeat", "active_bundle_id", activeBundleID)
if err := stream.Send(&bundlev1.WatchBundleRequest{
PdpId: c.conf.PDPIdentifier,
PdpId: c.PDPIdentifier,
Msg: &bundlev1.WatchBundleRequest_Heartbeat_{
Heartbeat: &bundlev1.WatchBundleRequest_Heartbeat{
Timestamp: timestamppb.Now(),
Expand Down
7 changes: 0 additions & 7 deletions bundle/client_conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ import (
"os"

"go.uber.org/multierr"

"github.com/cerbos/cloud-api/base"
)

type ClientConf struct {
CacheDir string
TempDir string
base.ClientConf
}

func (cc ClientConf) Validate() (outErr error) {
if err := cc.ClientConf.Validate(); err != nil {
outErr = multierr.Append(outErr, err)
}

if cc.CacheDir != "" {
if err := validateDir(cc.CacheDir); err != nil {
outErr = multierr.Append(outErr, err)
Expand Down
31 changes: 16 additions & 15 deletions bundle/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
bundlev1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bundle/v1"
"github.com/cerbos/cloud-api/genpb/cerbos/cloud/bundle/v1/bundlev1connect"
pdpv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/pdp/v1"
"github.com/cerbos/cloud-api/hub"
mockapikeyv1connect "github.com/cerbos/cloud-api/test/mocks/genpb/cerbos/cloud/apikey/v1/apikeyv1connect"
mockbundlev1connect "github.com/cerbos/cloud-api/test/mocks/genpb/cerbos/cloud/bundle/v1/bundlev1connect"
)
Expand Down Expand Up @@ -1079,23 +1080,23 @@ func mkClient(t *testing.T, url string, cert *x509.Certificate) (*bundle.Client,
creds, err := credentials.New("client-id", "client-secret", testPrivateKey)
require.NoError(t, err, "Failed to create credentials")

conf := bundle.ClientConf{
ClientConf: base.ClientConf{
Credentials: creds,
BootstrapEndpoint: url,
APIEndpoint: url,
PDPIdentifier: pdpIdentifer,
RetryWaitMin: 10 * time.Millisecond,
RetryWaitMax: 30 * time.Millisecond,
RetryMaxAttempts: 2,
Logger: testr.NewWithOptions(t, testr.Options{Verbosity: 4}),
TLS: tlsConf,
},
h, err := hub.New(base.ClientConf{
Credentials: creds,
BootstrapEndpoint: url,
APIEndpoint: url,
PDPIdentifier: pdpIdentifer,
RetryWaitMin: 10 * time.Millisecond,
RetryWaitMax: 30 * time.Millisecond,
RetryMaxAttempts: 2,
Logger: testr.NewWithOptions(t, testr.Options{Verbosity: 4}),
TLS: tlsConf,
})
require.NoError(t, err, "Failed to initialize hub")

client, err := h.BundleClient(bundle.ClientConf{
CacheDir: cacheDir,
TempDir: tempDir,
}

client, err := bundle.NewClient(conf)
})
require.NoError(t, err, "Failed to create client")

return client, creds
Expand Down
70 changes: 70 additions & 0 deletions hub/hub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright 2021-2024 Zenauth Ltd.
// SPDX-License-Identifier: Apache-2.0

package hub

import (
"fmt"
"sync"

"connectrpc.com/connect"

"github.com/cerbos/cloud-api/base"
"github.com/cerbos/cloud-api/bundle"
"github.com/cerbos/cloud-api/logcap"
)

var (
mu sync.RWMutex
instance *Hub
)

type Hub struct {
opts []connect.ClientOption
client base.Client
}

func Get(conf base.ClientConf) (*Hub, error) {
mu.RLock()
i := instance
mu.RUnlock()

if i != nil {
return i, nil
}

mu.Lock()
defer mu.Unlock()

if instance != nil {
return instance, nil
}

i, err := New(conf)
if err != nil {
return nil, fmt.Errorf("failed to create base Hub client: %w", err)
}

instance = i
return instance, nil
}

func New(conf base.ClientConf) (*Hub, error) {
client, opts, err := base.NewClient(conf)
if err != nil {
return nil, fmt.Errorf("failed to create base Hub client: %w", err)
}

return &Hub{
client: client,
opts: opts,
}, nil
}

func (h *Hub) BundleClient(conf bundle.ClientConf) (*bundle.Client, error) {
return bundle.NewClient(conf, h.client, h.opts)
}

func (h *Hub) LogCapClient() (*logcap.Client, error) {
return logcap.NewClient(h.client, h.opts)
}
Loading

0 comments on commit 654b2f7

Please sign in to comment.