diff --git a/internal/authn/preshared/authn.go b/internal/authn/preshared/authn.go index cb954aaac..0184ce1e8 100644 --- a/internal/authn/preshared/authn.go +++ b/internal/authn/preshared/authn.go @@ -48,17 +48,3 @@ func (a *KeyAuthn) Authenticate(ctx context.Context) error { } return status.Error(codes.Unauthenticated, base.ErrorCode_ERROR_CODE_INVALID_KEY.String()) } - -// Get Request Metadata - gets the current request metadata, refreshing tokens -// if required -func (a *KeyAuthn) GetRequestMetadata(_ context.Context, uri ...string) (map[string]string, error) { - return map[string]string{ - "Authorization": "Bearer " + "test", - }, nil -} - -// RequireTransportSecurity indicates whether the credentials requires -// transport security. -func (a *KeyAuthn) RequireTransportSecurity() bool { - return true -} diff --git a/internal/engines/balancer/balancer.go b/internal/engines/balancer/balancer.go index 1f553a42f..6e4660f53 100644 --- a/internal/engines/balancer/balancer.go +++ b/internal/engines/balancer/balancer.go @@ -12,7 +12,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "github.com/Permify/permify/internal/authn/preshared" "github.com/Permify/permify/internal/config" "github.com/Permify/permify/internal/engines" "github.com/Permify/permify/internal/invoke" @@ -34,22 +33,27 @@ type Balancer struct { options []grpc.DialOption } -// NewCheckEngineWithBalancer -// struct with the provided cache.Cache instance. +// NewCheckEngineWithBalancer creates a new check engine with a load balancer. +// It takes a Check interface, SchemaReader, distributed config, gRPC config, and authn config as input. +// It returns a Check interface and an error if any. func NewCheckEngineWithBalancer( + ctx context.Context, checker invoke.Check, schemaReader storage.SchemaReader, dst *config.Distributed, srv *config.GRPC, authn *config.Authn, ) (invoke.Check, error) { - var err error - - var options []grpc.DialOption - - var creds credentials.TransportCredentials + var ( + creds credentials.TransportCredentials + options []grpc.DialOption + isSecure bool + err error + ) + // Set up TLS credentials if paths are provided if srv.TLSConfig.CertPath != "" && srv.TLSConfig.KeyPath != "" { + isSecure = true creds, err = credentials.NewClientTLSFromFile(srv.TLSConfig.CertPath, srv.TLSConfig.KeyPath) if err != nil { return nil, fmt.Errorf("could not load TLS certificate: %s", err) @@ -58,22 +62,26 @@ func NewCheckEngineWithBalancer( creds = insecure.NewCredentials() } - // TODO: Add client-side authentication using a key from KeyAuthn. - // 1. Initialize the KeyAuthn structure using the provided configuration. - // 2. Convert the KeyAuthn instance into PerRPCCredentials. - // 3. Append grpc.WithPerRPCCredentials() to the options slice. - createPresharedKeyAuthN, err := preshared.NewKeyAuthn(context.Background(), authn.Preshared) - if err != nil { - return nil, fmt.Errorf("could not create authentication key: %s", err) - } - + // Append common options options = append( options, grpc.WithDefaultServiceConfig(grpcServicePolicy), grpc.WithTransportCredentials(creds), - grpc.WithPerRPCCredentials(createPresharedKeyAuthN), ) + // Handle authentication if enabled + if authn != nil && authn.Enabled { + token, err := setupAuthn(ctx, authn) + if err != nil { + return nil, err + } + if isSecure { + options = append(options, grpc.WithPerRPCCredentials(secureTokenCredentials{"authorization": "Bearer " + token})) + } else { + options = append(options, grpc.WithPerRPCCredentials(nonSecureTokenCredentials{"authorization": "Bearer " + token})) + } + } + conn, err := grpc.Dial(dst.Address, options...) if err != nil { return nil, err diff --git a/internal/engines/balancer/utils.go b/internal/engines/balancer/utils.go new file mode 100644 index 000000000..3ffceb757 --- /dev/null +++ b/internal/engines/balancer/utils.go @@ -0,0 +1,96 @@ +package balancer + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/Permify/permify/internal/config" +) + +// secureTokenCredentials represents a map used for storing secure tokens. +// These tokens require transport security. +type secureTokenCredentials map[string]string + +// RequireTransportSecurity indicates that transport security is required for these credentials. +func (c secureTokenCredentials) RequireTransportSecurity() bool { + return true // Transport security is required for secure tokens. +} + +// GetRequestMetadata retrieves the current metadata (secure tokens) for a request. +func (c secureTokenCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return c, nil // Returns the secure tokens as metadata with no error. +} + +// nonSecureTokenCredentials represents a map used for storing non-secure tokens. +// These tokens do not require transport security. +type nonSecureTokenCredentials map[string]string + +// RequireTransportSecurity indicates that transport security is not required for these credentials. +func (c nonSecureTokenCredentials) RequireTransportSecurity() bool { + return false // Transport security is not required for non-secure tokens. +} + +// GetRequestMetadata retrieves the current metadata (non-secure tokens) for a request. +func (c nonSecureTokenCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + return c, nil // Returns the non-secure tokens as metadata with no error. +} + +// OIDCTokenResponse represents the response from the OIDC token endpoint +type OIDCTokenResponse struct { + AccessToken string `json:"access_token"` +} + +func getOIDCToken(ctx context.Context, issuer, clientID string) (string, error) { + // Prepare the request data + data := url.Values{} + data.Set("client_id", clientID) + data.Set("grant_type", "client_credentials") + + // Create the request + req, err := http.NewRequestWithContext(ctx, "POST", issuer+"/token", strings.NewReader(data.Encode())) + if err != nil { + return "", fmt.Errorf("error creating token request: %v", err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + // Send the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error sending token request: %v", err) + } + defer resp.Body.Close() + + // Decode the response + var tokenResponse OIDCTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return "", fmt.Errorf("error decoding token response: %v", err) + } + + return tokenResponse.AccessToken, nil +} + +// setupAuthn configures the authentication token based on the provided authentication method. +// It returns the token string and an error if any. +func setupAuthn(ctx context.Context, authn *config.Authn) (string, error) { + var token string + var err error + + switch authn.Method { + case "preshared": + token = authn.Preshared.Keys[0] + case "oidc": + token, err = getOIDCToken(ctx, authn.Oidc.Issuer, authn.Oidc.ClientID) + if err != nil { + return "", fmt.Errorf("failed to get OIDC token: %s", err) + } + default: + return "", fmt.Errorf("unknown authentication method: '%s'", authn.Method) + } + + return token, nil +} diff --git a/pkg/cmd/serve.go b/pkg/cmd/serve.go index ce6e9077a..04ccde835 100644 --- a/pkg/cmd/serve.go +++ b/pkg/cmd/serve.go @@ -236,6 +236,7 @@ func serve() func(cmd *cobra.Command, args []string) error { // Create the checker either with load balancing or caching capabilities. if cfg.Distributed.Enabled { checker, err = balancer.NewCheckEngineWithBalancer( + context.Background(), checkEngine, schemaReader, &cfg.Distributed,