diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..495b5de --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +goimports -l -w . # includes go fmt +golangci-lint run # includes golint, go vet diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..692dc78 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,49 @@ +name: Benchmark Performance +on: + pull_request: + branches: + - master + - main + push: + branches: + - master + - main +permissions: + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write + +jobs: + benchmark: + name: Performance regression check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: "stable" + - name: Install dependencies + run: go mod tidy + - name: Run benchmark + run: go test ./... -bench=. -benchmem -count 2 -timeout 1m | tee benchmarks.txt + # Download previous benchmark result from cache (if exists) + - name: Download previous benchmark data + uses: actions/cache@v4 + with: + path: ./cache + key: ${{ runner.os }}-benchmark + # Run `github-action-benchmark` action + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + name: Go Benchmark + tool: 'go' + output-file-path: benchmarks.txt + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true + # Show alert with commit comment on detecting possible performance regression + alert-threshold: '200%' + comment-on-alert: true + fail-on-alert: true + alert-comment-cc-users: '@ndyakov' diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 499a5bf..e167232 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: - name: Install dependencies run: go mod tidy - name: Run tests with coverage - run: go test ./... -coverprofile=./cover.out -covermode=atomic -coverpkg=./... + run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 5m - name: Upload coverage uses: actions/upload-artifact@v4 with: diff --git a/.gitignore b/.gitignore index 8455110..ce0bc89 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,8 @@ *.tar.gz *.dic coverage.txt +cover.out **/coverage.txt +**/cover.out .vscode +tmp/ diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..cd7f519 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,7 @@ +version: "2" +run: + tests: false +linters: + disable: + - depguard + diff --git a/.testcoverage.yml b/.testcoverage.yml index fa9a770..3c88348 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -11,15 +11,15 @@ profile: cover.out threshold: # (optional; default 0) # Minimum coverage percentage required for individual files. - file: 70 + file: 85 # (optional; default 0) # Minimum coverage percentage required for each package. - package: 80 + package: 85 # (optional; default 0) # Minimum overall project coverage percentage required. - total: 95 + total: 90 # Holds regexp rules which will override thresholds for matched files or packages # using their paths. @@ -28,9 +28,9 @@ threshold: # new threshold to it. If project has multiple rules that match same path, # override rules should be listed in order from specific to more general rules. override: - # Increase coverage threshold to 100% for `foo` package - # (default is 80, as configured above in this example). - - path: ^pkg/lib/foo$ + - path: ^internal$ + threshold: 95 + - path: ^token$ threshold: 100 # Holds regexp rules which will exclude matched files or packages diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9947ff1 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,103 @@ +# Contributing to go-redis-entraid + +We welcome contributions from the community! If you'd like to contribute to this project, please follow these guidelines: + +## Getting Started + +1. Fork the repository +2. Create a new branch for your feature or bugfix +3. Make your changes +4. Run the tests and ensure they pass +5. Submit a pull request + +## Development Setup + +```bash +# Clone your fork +git clone https://github.com/your-username/go-redis-entraid.git +cd go-redis-entraid + +# Install dependencies +go mod download + +# Run tests +go test ./... +``` + +## Code Style and Standards + +- Follow the Go standard formatting (`go fmt`) +- Write clear and concise commit messages +- Include tests for new features +- Update documentation as needed +- Follow the existing code style and patterns + +## Testing + +We maintain high test coverage for the project. When contributing: + +- Add tests for new features +- Ensure existing tests pass +- Run the test coverage tool: + ```bash + go test -coverprofile=cover.out ./... + go tool cover -html=cover.out + ``` + +## Pull Request Process + +1. Ensure your code passes all tests +2. Update the README.md if necessary +3. Submit your pull request with a clear description of the changes + +## Reporting Issues + +If you find a bug or have a feature request: + +1. Check the existing issues to avoid duplicates +2. Create a new issue with: + - A clear title and description + - Steps to reproduce (for bugs) + - Expected and actual behavior + - Environment details (Go version, OS, etc.) + +## Development Workflow + +1. Create a new branch for your feature/fix: + ```bash + git checkout -b feature/your-feature-name + ``` + +2. Make your changes and commit them: + ```bash + git add . + git commit -m "Description of your changes" + ``` + +3. Push your changes to your fork: + ```bash + git push origin feature/your-feature-name + ``` + +4. Create a pull request from your fork to the main repository + +## Review Process + +- All pull requests will be reviewed by maintainers +- Be prepared to make changes based on feedback +- Ensure your code meets the project's standards +- Address any CI/CD failures + +## Documentation + +- Update relevant documentation when making changes +- Include examples for new features +- Update the README if necessary +- Add comments to complex code sections + +## Questions? + +If you have any questions about contributing, please: +1. Check the existing documentation +2. Look through existing issues +3. Create a new issue if your question hasn't been answered \ No newline at end of file diff --git a/README.md b/README.md index e5b9399..ff184ef 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,793 @@ # go-redis-entraid Entra ID extension for go-redis + +## Table of Contents +- [Introduction](#introduction) +- [Quick Start](#quick-start) +- [Architecture Overview](#architecture-overview) +- [Authentication Providers](#authentication-providers) +- [Configuration Guide](#configuration-guide) +- [Examples](#examples) +- [Testing](#testing) +- [FAQ](#faq) + +## Introduction + +go-redis-entraid is a Go library that provides Entra ID (formerly Azure AD) authentication support for Redis Enterprise Cloud. It enables secure authentication using various Entra ID identity types and manages token lifecycle automatically. + +### Version Compatibility +- Go: 1.16+ +- Redis: 6.0+ +- Azure Entra ID: Latest + +### Key Features +- Support for multiple Entra ID identity types +- Automatic token refresh and management +- Configurable token refresh policies +- Retry mechanisms with exponential backoff +- Thread-safe token management +- Streaming credentials provider interface + +## Quick Start + +### Minimal Example +Here's the simplest way to get started: + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + "github.com/redis-developer/go-redis-entraid/entraid" + "github.com/redis/go-redis/v9" +) + +func main() { + // Get required environment variables + clientID := os.Getenv("AZURE_CLIENT_ID") + redisEndpoint := os.Getenv("REDIS_ENDPOINT") + if clientID == "" || redisEndpoint == "" { + log.Fatal("AZURE_CLIENT_ID and REDIS_ENDPOINT environment variables are required") + } + + // Create credentials provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: clientID, + }, + }) + if err != nil { + log.Fatalf("Failed to create credentials provider: %v", err) + } + + // Create Redis client + client := redis.NewClient(&redis.Options{ + Addr: redisEndpoint, + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + log.Println("Connected to Redis!") +} +``` + +### Environment Setup +```bash +# Required environment variables +export AZURE_CLIENT_ID="your-client-id" +export REDIS_ENDPOINT="your-redis-endpoint:6380" + +# Optional environment variables +export AZURE_TENANT_ID="your-tenant-id" +export AZURE_CLIENT_SECRET="your-client-secret" +export AZURE_AUTHORITY_HOST="https://login.microsoftonline.com" # For custom authority +``` + +### Running the Example +```bash +go mod init your-app +go get github.com/redis-developer/go-redis-entraid +go run main.go +``` + +## Architecture Overview + +### Component Diagram +```mermaid +graph TD + A[Redis Client] --> B[StreamingCredentialsProvider] + B --> C[Token Manager] + C --> D[Identity Provider] + D --> E[Azure Entra ID] + + subgraph "Token Management" + C --> F[Token Cache] + C --> G[Token Refresh] + C --> H[Error Handling] + end + + subgraph "Identity Providers" + D --> I[Managed Identity] + D --> J[Confidential Client] + D --> K[Default Azure Identity] + D --> L[Custom Provider] + end +``` + +### Token Lifecycle +```mermaid +sequenceDiagram + participant Client + participant TokenManager + participant IdentityProvider + participant Azure + + Client->>TokenManager: GetToken() + alt Token Valid + TokenManager->>Client: Return cached token + else Token Expired + TokenManager->>IdentityProvider: RequestToken() + IdentityProvider->>Azure: Authenticate + Azure->>IdentityProvider: Return token + IdentityProvider->>TokenManager: Cache token + TokenManager->>Client: Return new token + end +``` + +### Component Responsibilities + +1. **Redis Client** + - Handles Redis connections + - Manages connection pooling + - Implements Redis protocol + +2. **StreamingCredentialsProvider** + - Provides authentication credentials + - Handles token refresh + - Manages authentication state + +3. **Token Manager** + - Caches tokens + - Handles token refresh + - Implements retry logic + - Manages token lifecycle + +4. **Identity Provider** + - Authenticates with Azure + - Handles different auth types + - Manages credentials + +## Authentication Providers + +### Provider Selection Guide + +```mermaid +graph TD + A[Choose Authentication] --> B{Managed Identity?} + B -->|Yes| C{System Assigned?} + B -->|No| D{Client Credentials?} + C -->|Yes| E[SystemAssignedIdentity] + C -->|No| F[UserAssignedIdentity] + D -->|Yes| G{Client Secret?} + D -->|No| H[DefaultAzureIdentity] + G -->|Yes| I[ClientSecret] + G -->|No| J[ClientCertificate] +``` + +### Provider Comparison + +| Provider Type | Best For | Security | Configuration | Performance | +|--------------|----------|----------|---------------|-------------| +| System Assigned | Azure-hosted apps | Highest | Minimal | Best | +| User Assigned | Shared identity | High | Moderate | Good | +| Client Secret | Service auth | High | Moderate | Good | +| Client Cert | High security | Highest | Complex | Good | +| Default Azure | Development | Moderate | Minimal | Good | + +## Configuration Guide + +### Environment Variables +```bash +# Required +AZURE_CLIENT_ID=your-client-id +REDIS_ENDPOINT=your-redis-endpoint:6380 + +# Optional +AZURE_TENANT_ID=your-tenant-id +AZURE_CLIENT_SECRET=your-client-secret +``` + +### Available Configuration Options + +#### 1. CredentialsProviderOptions +Base options for all credential providers: +```go +type CredentialsProviderOptions struct { + // Required: Client ID for authentication + ClientID string + + // Optional: Token manager configuration + TokenManagerOptions manager.TokenManagerOptions +} +``` + +#### 2. TokenManagerOptions +Options for token management: +```go +type TokenManagerOptions struct { + // Optional: Ratio of token lifetime to trigger refresh (0-1) + // Default: 0.7 (refresh at 70% of token lifetime) + ExpirationRefreshRatio float64 + + // Optional: Minimum time before expiration to trigger refresh + // Default: 0 (no lower bound, refresh based on ExpirationRefreshRatio) + LowerRefreshBound time.Duration + + // Optional: Custom response parser + IdentityProviderResponseParser shared.IdentityProviderResponseParser + + // Optional: Configuration for retry behavior + RetryOptions RetryOptions + + // Optional: Timeout for token requests + RequestTimeout time.Duration +} +``` + +#### 3. RetryOptions +Options for retry behavior: +```go +type RetryOptions struct { + // Optional: Function to determine if an error is retryable + // Default: Checks for network errors and timeouts + IsRetryable func(err error) bool + + // Optional: Maximum number of retry attempts + // Default: 3 + MaxAttempts int + + // Optional: Initial delay between retries + // Default: 1 second + InitialDelay time.Duration + + // Optional: Maximum delay between retries + // Default: 10 seconds + MaxDelay time.Duration + + // Optional: Multiplier for exponential backoff + // Default: 2.0 + BackoffMultiplier float64 +} +``` + +#### 4. ManagedIdentityProviderOptions +Options for managed identity authentication: +```go +type ManagedIdentityProviderOptions struct { + // Required: Type of managed identity + ManagedIdentityType ManagedIdentityType // SystemAssignedIdentity or UserAssignedIdentity + + // Optional: Client ID for user-assigned identity + UserAssignedClientID string + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +#### 5. ConfidentialIdentityProviderOptions +Options for confidential client authentication: +```go +type ConfidentialIdentityProviderOptions struct { + // Required: Client ID for authentication + ClientID string + + // Required: Type of credentials + CredentialsType string // identity.ClientSecretCredentialType or identity.ClientCertificateCredentialType + + // Required for ClientSecret: Client secret value + ClientSecret string + + // Required for ClientCertificate: Client certificate + // Type: []*x509.Certificate + ClientCert []*x509.Certificate + + // Required for ClientCertificate: Client private key + // Type: crypto.PrivateKey + ClientPrivateKey crypto.PrivateKey + + // Required: Authority configuration + Authority AuthorityConfiguration + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +#### 6. AuthorityConfiguration +Options for authority configuration: +```go +type AuthorityConfiguration struct { + // Required: Type of authority + AuthorityType AuthorityType // "default", "multi-tenant", or "custom" + + // Required: Azure AD tenant ID + // Use "common" for multi-tenant applications + TenantID string + + // Optional: Custom authority URL + // Required for custom authority type + Authority string +} +``` + +#### 7. DefaultAzureIdentityProviderOptions +Options for default Azure identity: +```go +type DefaultAzureIdentityProviderOptions struct { + // Optional: Azure identity provider options + AzureOptions *azidentity.DefaultAzureCredentialOptions + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +### Configuration Examples + +#### Basic Configuration +```go +options := entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBounds: 10000, + }, +} +``` + +#### Advanced Configuration +```go +options := entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBounds: 10000, + RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelay: 1000 * time.Millisecond, + MaxDelay: 30000 * time.Millisecond, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, + }, + }, +} +``` + +#### Authority Configuration +```go +// Multi-tenant application +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeMultiTenant, + TenantID: "common", +} + +// Single-tenant application +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), +} + +// Custom authority +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeCustom, + TenantID: os.Getenv("AZURE_TENANT_ID"), + Authority: fmt.Sprintf("%s/%s/v2.0", + os.Getenv("AZURE_AUTHORITY_HOST"), + os.Getenv("AZURE_TENANT_ID")), +} +``` + +## Examples + +### System Assigned Identity +```go +// Create provider for system assigned identity +provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + ManagedIdentityType: identity.SystemAssignedIdentity, +}) +``` + +### User Assigned Identity +```go +// Create provider for user assigned identity +provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + ManagedIdentityType: identity.UserAssignedIdentity, + UserAssignedClientID: os.Getenv("USER_ASSIGNED_CLIENT_ID"), +}) +``` + +### Client Secret Authentication +```go +// Create provider for client secret authentication +provider, err := entraid.NewConfidentialCredentialsProvider(entraid.ConfidentialIdentityProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: os.Getenv("AZURE_CLIENT_SECRET"), + Authority: identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), + }, +}) +``` + +### Client Certificate Authentication +```go +// Create provider for client certificate authentication +cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") +if err != nil { + log.Fatal(err) +} + +provider, err := entraid.NewConfidentialCredentialsProvider(entraid.ConfidentialIdentityProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + CredentialsType: identity.ClientCertificateCredentialType, + ClientCert: []*x509.Certificate{cert.Leaf}, + ClientPrivateKey: cert.PrivateKey, + Authority: identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), + }, +}) +``` + +### Advanced Usage with Custom Identity Provider + +This example shows how to implement your own IdentityProvider while leveraging our TokenManager and StreamingCredentialsProvider. This is useful when you need to authenticate with a custom token source but want to benefit from our token management and streaming capabilities. + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + "time" + + "github.com/redis-developer/go-redis-entraid/entraid" + "github.com/redis-developer/go-redis-entraid/entraid/identity" + "github.com/redis-developer/go-redis-entraid/entraid/manager" + "github.com/redis-developer/go-redis-entraid/entraid/shared" + "github.com/redis/go-redis/v9" +) + +// CustomIdentityProvider implements the IdentityProvider interface +type CustomIdentityProvider struct { + // Add any fields needed for your custom authentication + tokenEndpoint string + clientID string + clientSecret string +} + +// RequestToken implements the IdentityProvider interface +func (p *CustomIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + // Implement your custom token retrieval logic here + // This could be calling your own auth service, using a different auth protocol, etc. + + // For this example, we'll simulate getting a JWT token + token := "your.jwt.token" + + // Create a response using NewIDPResponse + return shared.NewIDPResponse(shared.ResponseTypeRawToken, token) +} + +func main() { + // Create your custom identity provider + customProvider := &CustomIdentityProvider{ + tokenEndpoint: "https://your-auth-endpoint.com/token", + clientID: os.Getenv("CUSTOM_CLIENT_ID"), + clientSecret: os.Getenv("CUSTOM_CLIENT_SECRET"), + } + + // Create token manager with your custom provider + tokenManager, err := manager.NewTokenManager(customProvider, manager.TokenManagerOptions{ + // Configure token refresh behavior + ExpirationRefreshRatio: 0.7, + LowerRefreshBound: time.Second * 10, + RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Second, + MaxDelay: time.Second * 10, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, + }, + RequestTimeout: time.Second * 30, + }) + if err != nil { + log.Fatalf("Failed to create token manager: %v", err) + } + + // Create credentials provider using our StreamingCredentialsProvider + provider, err := entraid.NewCredentialsProvider(tokenManager, entraid.CredentialsProviderOptions{ + // Add any additional options needed + OnReAuthenticationError: func(err error) error { + log.Printf("Re-authentication error: %v", err) + return err + }, + }) + if err != nil { + log.Fatalf("Failed to create credentials provider: %v", err) + } + + // Create Redis client with your custom provider + client := redis.NewClient(&redis.Options{ + Addr: os.Getenv("REDIS_ENDPOINT"), + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test the connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + log.Println("Connected to Redis with custom identity provider!") +} +``` + +Key points about this implementation: + +1. **Custom Identity Provider**: + - Implements the `IdentityProvider` interface with `RequestToken` method + - Returns a response using `shared.NewIDPResponse` with `ResponseTypeRawToken` + - Handles your custom authentication logic + +2. **Token Management**: + - Uses our `TokenManager` for automatic token refresh + - Benefits from our retry mechanisms + - Handles token caching and lifecycle + - Configurable refresh timing and retry behavior + +3. **Streaming Credentials**: + - Uses our `StreamingCredentialsProvider` for Redis integration + - Handles connection authentication + - Manages token streaming to Redis + +4. **Error Handling**: + - Implements proper error handling + - Uses our error callback mechanisms + - Provides logging and monitoring hooks + +This approach gives you the flexibility of custom authentication while benefiting from our robust token management and Redis integration features. + +## Testing + +### Unit Testing +```go +func TestManagedIdentityProvider(t *testing.T) { + // Create test provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: "test-client-id", + }, + }) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Test token retrieval + token, err := provider.GetToken(context.Background()) + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + if token == "" { + t.Error("Expected non-empty token") + } +} +``` + +### Integration Testing +```go +func TestRedisConnection(t *testing.T) { + // Create provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + }) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Create Redis client + client := redis.NewClient(&redis.Options{ + Addr: os.Getenv("REDIS_ENDPOINT"), + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + t.Fatalf("Failed to connect to Redis: %v", err) + } +} +``` + +## FAQ + +### Q: How do I handle token expiration? +A: The library handles token expiration automatically. Tokens are refreshed when they reach 70% of their lifetime (configurable via `ExpirationRefreshRatio`). You can also set a minimum time before expiration to trigger refresh using `LowerRefreshBound`. The token manager will automatically handle token refresh and caching. + +### Q: How do I handle connection failures? +A: The library includes built-in retry mechanisms in the TokenManager. You can configure retry behavior using `RetryOptions`: +```go +RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Second, + MaxDelay: time.Second * 10, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, +} +``` + +### Q: What happens if token refresh fails? +A: The library will retry according to the configured `RetryOptions`. If all retries fail, the error will be propagated to the client. You can customize the retry behavior by: +1. Setting the maximum number of attempts +2. Configuring the initial and maximum delay between retries using `time.Duration` values +3. Setting the backoff multiplier for exponential backoff +4. Providing a custom function to determine which errors are retryable + +### Q: How do I implement custom authentication? +A: You can create a custom identity provider by implementing the `IdentityProvider` interface: +```go +type IdentityProvider interface { + // RequestToken requests a token from the identity provider. + // The context is passed to the request to allow for cancellation and timeouts. + // It returns the token, the expiration time, and an error if any. + RequestToken(ctx context.Context) (IdentityProviderResponse, error) +} +``` + +The response types are defined as constants: +```go +const ( + // ResponseTypeAuthResult is the type of the auth result. + ResponseTypeAuthResult = "AuthResult" + // ResponseTypeAccessToken is the type of the access token. + ResponseTypeAccessToken = "AccessToken" + // ResponseTypeRawToken is the type of the response when you have a raw string. + ResponseTypeRawToken = "RawToken" +) +``` + +The `IdentityProviderResponse` interface and related interfaces provide methods to access the authentication result: +```go +// IdentityProviderResponse is the base interface that defines the type method +type IdentityProviderResponse interface { + // Type returns the type of identity provider response + Type() string +} + +// AuthResultIDPResponse defines the method for getting the auth result +type AuthResultIDPResponse interface { + AuthResult() public.AuthResult +} + +// AccessTokenIDPResponse defines the method for getting the access token +type AccessTokenIDPResponse interface { + AccessToken() azcore.AccessToken +} + +// RawTokenIDPResponse defines the method for getting the raw token +type RawTokenIDPResponse interface { + RawToken() string +} +``` + +You can create a new response using the `NewIDPResponse` function: +```go +// NewIDPResponse creates a new auth result based on the type provided. +// Type can be either AuthResult, AccessToken, or RawToken. +// Second argument is the result of the type provided in the first argument. +func NewIDPResponse(responseType string, result interface{}) (IdentityProviderResponse, error) +``` + +Here's an example of how to use these types in a custom identity provider: +```go +type CustomIdentityProvider struct { + tokenEndpoint string + clientID string + clientSecret string +} + +func (p *CustomIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + // Get the token from your custom auth service + token, err := p.getTokenFromCustomService() + if err != nil { + return nil, err + } + + // Create a response based on the token type + switch token.Type { + case "jwt": + return shared.NewIDPResponse(shared.ResponseTypeRawToken, token.Value) + case "access_token": + return shared.NewIDPResponse(shared.ResponseTypeAccessToken, token.Value) + case "auth_result": + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, token.Value) + default: + return nil, fmt.Errorf("unsupported token type: %s", token.Type) + } +} +``` + +### Q: Can I customize how token responses are parsed? +A: Yes, you can provide a custom `IdentityProviderResponseParser` in the `TokenManagerOptions`. This allows you to handle custom token formats or implement special parsing logic. + +### Q: What's the difference between managed identity types? +A: There are three main types of managed identities in Azure: + +1. **System Assigned Managed Identity**: + - Automatically created and managed by Azure + - Tied directly to a specific Azure resource (VM, App Service, etc.) + - Cannot be shared between resources + - Automatically deleted when the resource is deleted + - Best for single-resource applications with dedicated identity + +2. **User Assigned Managed Identity**: + - Created and managed independently of resources + - Can be assigned to multiple Azure resources + - Has its own lifecycle independent of resources + - Can be shared across multiple resources + - Best for applications that need a shared identity or run across multiple resources + +3. **Default Azure Identity**: + - Uses environment-based authentication + - Automatically tries multiple authentication methods in sequence: + 1. Environment variables + 2. Managed Identity + 3. Visual Studio Code + 4. Azure CLI + 5. Visual Studio + - Best for development and testing environments + - Provides flexibility during development without changing code + +The choice between these types depends on your specific use case: +- Use System Assigned for single-resource applications +- Use User Assigned for shared identity scenarios +- Use Default Azure Identity for development and testing \ No newline at end of file diff --git a/credentials_provider.go b/credentials_provider.go new file mode 100644 index 0000000..30ff8e5 --- /dev/null +++ b/credentials_provider.go @@ -0,0 +1,143 @@ +// Package entraid provides a credentials provider that manages token retrieval and notifies listeners +// of token updates. It implements the auth.StreamingCredentialsProvider interface and is designed +// for use with the Redis authentication system. +package entraid + +import ( + "fmt" + "sync" + + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" +) + +// Ensure entraidCredentialsProvider implements the auth.StreamingCredentialsProvider interface. +var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil) + +// entraidCredentialsProvider is a struct that implements the StreamingCredentialsProvider interface. +type entraidCredentialsProvider struct { + options CredentialsProviderOptions // Configuration options for the provider. + + tokenManager manager.TokenManager // Manages token retrieval. + stopTokenManager manager.StopFunc // Function to stop the token manager. + + // listeners is a slice of listeners that are notified when the token manager receives a new token. + listeners []auth.CredentialsListener // Slice of listeners notified on token updates. + + // rwLock is a mutex that is used to synchronize access to the listeners slice. + rwLock sync.RWMutex // Mutex for synchronizing access to the listeners slice. +} + +// onTokenNext is a method that is called when the token manager receives a new token. +// It notifies all registered listeners with the new token. +func (e *entraidCredentialsProvider) onTokenNext(t *token.Token) { + e.rwLock.RLock() + defer e.rwLock.RUnlock() + // Notify all listeners with the new token. + for _, listener := range e.listeners { + listener.OnNext(t) + } +} + +// onTokenError is a method that is called when the token manager encounters an error. +// It notifies all registered listeners with the error. +func (e *entraidCredentialsProvider) onTokenError(err error) { + e.rwLock.RLock() + defer e.rwLock.RUnlock() + + // Notify all listeners with the error + for _, listener := range e.listeners { + listener.OnError(err) + } +} + +// Subscribe subscribes a listener to the credentials provider. +// It returns the current credentials, a cancel function to unsubscribe, and an error if the subscription fails. +// +// Parameters: +// - listener: The listener that will receive updates about token changes. +// +// Returns: +// - auth.Credentials: The current credentials for the listener. +// - auth.CancelProviderFunc: A function that can be called to unsubscribe the listener. +// - error: An error if the subscription fails, such as if the token cannot be retrieved. +// +// Note: If the listener is already subscribed, it will not receive duplicate notifications. +func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { + // First try to get a token, only then subscribe the listener. + token, err := e.tokenManager.GetToken(false) + if err != nil { + return nil, nil, err + } + + e.rwLock.Lock() + // Check if the listener is already in the list of listeners. + alreadySubscribed := false + for _, l := range e.listeners { + if l == listener { + alreadySubscribed = true + break + } + } + + if !alreadySubscribed { + // add new listener + e.listeners = append(e.listeners, listener) + } + e.rwLock.Unlock() + + unsub := func() error { + // Remove the listener from the list of listeners. + e.rwLock.Lock() + defer e.rwLock.Unlock() + + for i, l := range e.listeners { + if l == listener { + e.listeners = append(e.listeners[:i], e.listeners[i+1:]...) + break + } + } + + // Clear the listeners slice if it's empty + if len(e.listeners) == 0 { + e.listeners = make([]auth.CredentialsListener, 0) + if e.stopTokenManager != nil { + err := e.stopTokenManager() + if err != nil { + return fmt.Errorf("couldn't cancel token manager: %w", err) + } + // Set the stopTokenManager to nil to indicate that it has been stopped. + // This prevents multiple calls to stopTokenManager. + e.stopTokenManager = nil + } + } + return nil + } + + return token, unsub, nil +} + +// NewCredentialsProvider creates a new credentials provider with the specified token manager and options. +// It returns a StreamingCredentialsProvider interface and an error if the token manager cannot be started. +// +// Parameters: +// - tokenManager: The TokenManager used to obtain tokens. +// - options: Options for configuring the credentials provider. +// +// Returns: +// - auth.StreamingCredentialsProvider: The newly created credentials provider. +// - error: An error if the token manager cannot be started. +func NewCredentialsProvider(tokenManager manager.TokenManager, options CredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + cp := &entraidCredentialsProvider{ + tokenManager: tokenManager, + options: options, + listeners: make([]auth.CredentialsListener, 0), + } + stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) + if err != nil { + return nil, fmt.Errorf("couldn't start token manager: %w", err) + } + cp.stopTokenManager = stopTM + return cp, nil +} diff --git a/credentials_provider_test.go b/credentials_provider_test.go new file mode 100644 index 0000000..bb4871d --- /dev/null +++ b/credentials_provider_test.go @@ -0,0 +1,582 @@ +package entraid + +import ( + "sync" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestCredentialsProviderErrorScenarios(t *testing.T) { + t.Run("token manager start error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: "invalid-type", // Invalid credentials type + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("token manager get token error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "", // Empty client secret + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent error handling", func(t *testing.T) { + // Create a test provider with invalid options + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", // Invalid managed identity type + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent token updates", func(t *testing.T) { + // Create a test provider with invalid options + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{}, // Empty scopes + }, + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + // bad options - empty scopes + assert.Error(t, err) + assert.Nil(t, provider) + }) +} + +func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { + t.Parallel() + + t.Run("Subscribe and Unsubscribe", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Multiple Listeners", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create multiple mock listeners + listener1 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + listener2 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe first listener + credentials1, cancel1, err := cp.Subscribe(listener1) + assert.NoError(t, err) + assert.NotNil(t, credentials1) + assert.NotNil(t, cancel1) + + // Subscribe second listener + credentials2, cancel2, err := cp.Subscribe(listener2) + assert.NoError(t, err) + assert.NotNil(t, credentials2) + assert.NotNil(t, cancel2) + + // Wait for initial tokens + token1, err := listener1.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token1) + + token2, err := listener2.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token2) + + // Unsubscribe first listener + err = cancel1() + assert.NoError(t, err) + + // Unsubscribe second listener + err = cancel2() + assert.NoError(t, err) + }) + + t.Run("Token Updates", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "initial-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "initial-token", tk) + + tm.lock.Lock() + // Update token + tm.token = token.New( + "test", + "test", + "updated-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + tm.lock.Unlock() + + // Wait for token update + tk, err = listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "updated-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Error Handling", func(t *testing.T) { + t.Parallel() + + // Create mock token manager with error + tm := &fakeTokenManager{ + err: assert.AnError, + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.Error(t, err) + assert.Nil(t, cp) + }) +} + +func TestCredentialsProviderOptions(t *testing.T) { + t.Run("default token manager factory", func(t *testing.T) { + options := CredentialsProviderOptions{} + factory := options.getTokenManagerFactory() + assert.NotNil(t, factory) + }) + + t.Run("custom token manager factory", func(t *testing.T) { + m := &fakeTokenManager{} + customFactory := func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return m, nil + } + options := CredentialsProviderOptions{ + tokenManagerFactory: customFactory, + } + tm, err := options.getTokenManagerFactory()(nil, manager.TokenManagerOptions{}) + assert.NotNil(t, tm) + assert.NoError(t, err) + assert.Equal(t, m, tm) + }) +} + +func TestCredentialsProviderSubscribe(t *testing.T) { + // Create a test provider + opts := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + t.Run("double subscribe and cancel resubscribe", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(testToken, nil) + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)). + Return(manager.StopFunc(mtm.Stop), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + // Subscribe the listener + tk, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + require.NotNil(t, tk) + require.NotNil(t, cancel) + // try to subscribe the same listener again + tk2, cancel2, err := provider.Subscribe(listener) + require.NoError(t, err) + require.NotNil(t, tk2) + require.NotNil(t, cancel2) + // Verify the listener received the token once + select { + case tk := <-listener.LastTokenCh: + assert.Equal(t, rawTokenString, tk, "listener received wrong token") + case err := <-listener.LastErrCh: + t.Fatalf("listener received error: %v", err) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener timed out waiting for token") + } + // verify it is not received again + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener received unexpected token: %v", tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener received unexpected error: %v", err) + case <-time.After(tokenExpiration / 2): + // No message received, which is expected + } + + }) + + t.Run("concurrent subscribe and cancel with error ", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(testToken, nil) + + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). + Return(manager.StopFunc(mtm.Stop), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + cancels := make([]auth.UnsubscribeFunc, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + _, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + cancels[idx] = cancel + }(i) + } + wg.Wait() + + // Verify all listeners received the token + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received token: %v", i, tk) + case err := <-listener.LastErrCh: + assert.Equal(t, errTokenError.Error(), err.Error(), "listener %d received wrong error", i) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener %d timed out waiting for token", i) + } + } + + // Cancel all subscriptions concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := cancels[idx]() + require.NoError(t, err) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) + + t.Run("concurrent subscribe and get token error ", func(t *testing.T) { + t.Parallel() + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(nil, assert.AnError) + + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). + Return(manager.StopFunc(mtm.Stop), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + tk, cancel, err := provider.Subscribe(listener) + require.Nil(t, tk) + require.Error(t, err) + require.Nil(t, cancel) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) + + t.Run("concurrent subscribe and cancel", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + cancels := make([]auth.UnsubscribeFunc, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + _, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + cancels[idx] = cancel + }(i) + } + wg.Wait() + + // Verify all listeners received the token + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + assert.Equal(t, rawTokenString, tk, "listener %d received wrong token", i) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received error: %v", i, err) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener %d timed out waiting for token", i) + } + } + + // Cancel all subscriptions concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := cancels[idx]() + require.NoError(t, err) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) +} diff --git a/entraid.go b/entraid.go new file mode 100644 index 0000000..06896c5 --- /dev/null +++ b/entraid.go @@ -0,0 +1,12 @@ +package entraid + +import "github.com/redis-developer/go-redis-entraid/shared" + +// IdentityProvider is an alias for the shared.IdentityProvider interface. +type IdentityProvider = shared.IdentityProvider + +// IdentityProviderResponse is an alias for the shared.IdentityProviderResponse interface. +type IdentityProviderResponse = shared.IdentityProviderResponse + +// IdentityProviderResponseParser is an alias for the shared.IdentityProviderResponseParser interface. +type IdentityProviderResponseParser = shared.IdentityProviderResponseParser diff --git a/entraid_test.go b/entraid_test.go new file mode 100644 index 0000000..5f64167 --- /dev/null +++ b/entraid_test.go @@ -0,0 +1,212 @@ +package entraid + +import ( + "errors" + "flag" + "sync" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/mock" +) + +// fakeTokenManager implements the TokenManager interface for testing +type fakeTokenManager struct { + token *token.Token + err error + lock sync.Mutex +} + +const rawTokenString = "mock-token" + +// numListeners is set to 3 for short tests and 12 for long tests +var numListeners = 12 + +// tokenExpiration is set to 100ms for long tests and 10ms for short tests +var tokenExpiration = 100 * time.Millisecond + +func init() { + testing.Init() + flag.Parse() + tokenExpiration = 100 * time.Millisecond + numListeners = 12 + if testing.Short() { + tokenExpiration = 10 * time.Millisecond + numListeners = 3 + } +} + +func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + if forceRefresh { + m.token = token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(100*time.Millisecond), + ) + } + return m.token, m.err +} + +func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) { + if m.err != nil { + return nil, m.err + } + done := make(chan struct{}) + go func() { + for { + select { + case <-time.After(tokenExpiration): + m.lock.Lock() + if m.err != nil { + listener.OnError(m.err) + return + } + listener.OnNext(m.token) + m.lock.Unlock() + case <-done: + // Exit the loop if done channel is closed + return + + } + } + }() + + return func() error { + close(done) + return nil + }, nil +} + +func (m *fakeTokenManager) Stop() error { + return nil +} + +// mockCredentialsListener implements the CredentialsListener interface for testing +type mockCredentialsListener struct { + LastTokenCh chan string + LastErrCh chan error +} + +func (m *mockCredentialsListener) readWithTimeout(timeout time.Duration) (string, error) { + select { + case tk := <-m.LastTokenCh: + return tk, nil + case err := <-m.LastErrCh: + return "", err + case <-time.After(timeout): + return "", errors.New("timeout waiting for token") + } +} + +func (m *mockCredentialsListener) OnNext(credentials auth.Credentials) { + if m.LastTokenCh == nil { + m.LastTokenCh = make(chan string) + } + m.LastTokenCh <- credentials.RawCredentials() +} + +func (m *mockCredentialsListener) OnError(err error) { + if m.LastErrCh == nil { + m.LastErrCh = make(chan error) + } + m.LastErrCh <- err +} + +// testFakeTokenManagerFactory is a factory function that returns a mock token manager +func testFakeTokenManagerFactory(tk *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + return &fakeTokenManager{ + token: tk, + err: err, + }, nil + } +} + +// mockTokenManager is a mock implementation of the TokenManager interface +type mockTokenManager struct { + mock.Mock + idp shared.IdentityProvider + done chan struct{} + options manager.TokenManagerOptions + listener manager.TokenListener + lock sync.Mutex +} + +func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + args := m.Called(forceRefresh) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*token.Token), args.Error(1) +} + +func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) { + args := m.Called(listener) + m.lock.Lock() + if m.done == nil { + m.done = make(chan struct{}) + } + if m.listener != nil { + defer m.lock.Unlock() + return nil, manager.ErrTokenManagerAlreadyStarted + } + if m.listener == nil { + m.listener = listener + } + m.lock.Unlock() + return args.Get(0).(manager.StopFunc), args.Error(1) +} +func (m *mockTokenManager) Stop() error { + m.lock.Lock() + defer m.lock.Unlock() + if m.listener == nil { + return manager.ErrTokenManagerAlreadyStopped + } + if m.listener != nil { + m.listener = nil + } + if m.done != nil { + close(m.done) + m.done = nil + } + return nil +} + +// mockTokenManagerFactory is a factory function that returns a mock token manager +func mockTokenManagerFactory(mtm *mockTokenManager) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + mtm.idp = provider + mtm.options = options + return mtm, nil + } +} + +var errTokenError = errors.New("token error") + +func mockTokenManagerLoop(mtm *mockTokenManager, tokenExpiration time.Duration, testToken *token.Token, err error) func(args mock.Arguments) { + return func(args mock.Arguments) { + go func() { + for { + select { + case <-mtm.done: + return + case <-time.After(tokenExpiration): + mtm.lock.Lock() + if err != nil { + mtm.listener.OnError(err) + } else { + mtm.listener.OnNext(testToken) + } + mtm.lock.Unlock() + } + } + }() + } +} diff --git a/go.mod b/go.mod index ad872dd..7e908f0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,30 @@ module github.com/redis-developer/go-redis-entraid go 1.18 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 + github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..36c644a --- /dev/null +++ b/go.sum @@ -0,0 +1,45 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 h1:iw4+KCeCoieuKodp1d5YhAa1TU/GgogCbw8RbGvsfLA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1/go.mod h1:AP8cDnDTGIVvayqKAhwzpcAyTJosXpvLYNmVFJb98x8= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.2.3 h1:BAUsn6/icUFtvUalVwCO0+hSF7qgU9DwwcEfCvtILtw= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/keybase/go-keychain v0.0.0-20231219164618-57a3676c3af6 h1:IsMZxCuZqKuao2vNdfD82fjjgPLfyHLpR41Z88viRWs= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 h1:/dxO9rhmlhKP5pyI7omDH3QQzC0AppWxHT1w5TBsdTU= +github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/identity/authority_configuration.go b/identity/authority_configuration.go new file mode 100644 index 0000000..bb229dd --- /dev/null +++ b/identity/authority_configuration.go @@ -0,0 +1,59 @@ +package identity + +import "fmt" + +const ( + // AuthorityTypeDefault is the default authority type. + // This is used to specify the authority type when requesting a token. + AuthorityTypeDefault = "default" + // AuthorityTypeMultiTenant is the multi-tenant authority type. + // This is used to specify the multi-tenant authority type when requesting a token. + // This type of authority is used to authenticate the identity when requesting a token. + AuthorityTypeMultiTenant = "multi-tenant" + // AuthorityTypeCustom is the custom authority type. + // This is used to specify the custom authority type when requesting a token. + AuthorityTypeCustom = "custom" +) + +// AuthorityConfiguration represents the authority configuration for the identity provider. +// It is used to configure the authority type and authority URL when requesting a token. +type AuthorityConfiguration struct { + // AuthorityType is the type of authority used to authenticate with the identity provider. + // This can be either "default", "multi-tenant", or "custom". + AuthorityType string + + // Authority is the authority used to authenticate with the identity provider. + // This is typically the URL of the identity provider. + // For example, "https://login.microsoftonline.com/{tenantID}/v2.0" + Authority string + + // TenantID is the tenant ID of the identity provider. + // This is used to identify the tenant when requesting a token. + // This is typically the ID of the Azure Active Directory tenant. + TenantID string +} + +// getAuthority returns the authority URL based on the authority type. +// The authority type can be either "default", "multi-tenant", or "custom". +func (a AuthorityConfiguration) getAuthority() (string, error) { + if a.AuthorityType == "" { + a.AuthorityType = AuthorityTypeDefault + } + + switch a.AuthorityType { + case AuthorityTypeDefault: + return "https://login.microsoftonline.com/common", nil + case AuthorityTypeMultiTenant: + if a.TenantID == "" { + return "", fmt.Errorf("tenant ID is required when using multi-tenant authority type") + } + return fmt.Sprintf("https://login.microsoftonline.com/%s", a.TenantID), nil + case AuthorityTypeCustom: + if a.Authority == "" { + return "", fmt.Errorf("authority is required when using custom authority type") + } + return a.Authority, nil + default: + return "", fmt.Errorf("invalid authority type: %s", a.AuthorityType) + } +} diff --git a/identity/authority_configuration_test.go b/identity/authority_configuration_test.go new file mode 100644 index 0000000..7ae4a67 --- /dev/null +++ b/identity/authority_configuration_test.go @@ -0,0 +1,160 @@ +package identity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAuthorityConfiguration(t *testing.T) { + t.Parallel() + tests := []struct { + name string + authorityType string + tenantID string + authority string + expected string + expectError bool + }{ + { + name: "Default Authority", + authorityType: AuthorityTypeDefault, + expected: "https://login.microsoftonline.com/common", + expectError: false, + }, + { + name: "Multi-Tenant Authority", + authorityType: AuthorityTypeMultiTenant, + tenantID: "12345", + expected: "https://login.microsoftonline.com/12345", + expectError: false, + }, + { + name: "Custom Authority", + authorityType: AuthorityTypeCustom, + authority: "https://custom-authority.com", + expected: "https://custom-authority.com", + expectError: false, + }, + { + name: "Invalid Authority Type", + authorityType: "invalid", + expectError: true, + }, + { + name: "Missing Tenant ID for Multi-Tenant", + authorityType: AuthorityTypeMultiTenant, + expectError: true, + }, + { + name: "Missing Authority for Custom", + authorityType: AuthorityTypeCustom, + expectError: true, + }, + { + name: "Default Authority Type with Tenant ID", + authorityType: AuthorityTypeDefault, + tenantID: "12345", + expected: "https://login.microsoftonline.com/common", + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ac := AuthorityConfiguration{ + AuthorityType: test.authorityType, + TenantID: test.tenantID, + Authority: test.authority, + } + result, err := ac.getAuthority() + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, test.expected, result) + } + }) + } +} + +func TestAuthorityConfigurationDefault(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{} + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} + +func TestAuthorityConfigurationMultiTenant(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeMultiTenant, + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/12345", result) +} + +func TestAuthorityConfigurationCustom(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://custom-authority.com", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://custom-authority.com", result) +} + +func TestAuthorityConfigurationInvalid(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: "invalid", + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationMissingTenantID(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeMultiTenant, + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationMissingAuthority(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationDefaultAuthorityType(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} + +func TestAuthorityConfigurationDefaultAuthorityTypeWithTenantID(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeDefault, + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} diff --git a/identity/azure_default_identity_provider.go b/identity/azure_default_identity_provider.go new file mode 100644 index 0000000..26e6a57 --- /dev/null +++ b/identity/azure_default_identity_provider.go @@ -0,0 +1,77 @@ +package identity + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// DefaultAzureIdentityProviderOptions represents the options for the DefaultAzureIdentityProvider. +type DefaultAzureIdentityProviderOptions struct { + // AzureOptions is the options used to configure the Azure identity provider. + AzureOptions *azidentity.DefaultAzureCredentialOptions + // Scopes is the list of scopes used to request a token from the identity provider. + Scopes []string + + // credFactory is a factory for creating the default Azure credential. + // This is used for testing purposes, to allow mocking the credential creation. + // If not provided, the default implementation - azidentity.NewDefaultAzureCredential will be used + credFactory credFactory +} + +type credFactory interface { + NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) +} + +type azureCredential interface { + GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) +} + +type defaultCredFactory struct{} + +func (d *defaultCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) { + return azidentity.NewDefaultAzureCredential(options) +} + +type DefaultAzureIdentityProvider struct { + options *azidentity.DefaultAzureCredentialOptions + credFactory credFactory + scopes []string +} + +// NewDefaultAzureIdentityProvider creates a new DefaultAzureIdentityProvider. +func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) (*DefaultAzureIdentityProvider, error) { + if opts.Scopes == nil { + opts.Scopes = []string{RedisScopeDefault} + } + + return &DefaultAzureIdentityProvider{ + options: opts.AzureOptions, + scopes: opts.Scopes, + credFactory: opts.credFactory, + }, nil +} + +// RequestToken requests a token from the Azure Default Identity provider. +// It returns the token, the expiration time, and an error if any. +func (a *DefaultAzureIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + credFactory := a.credFactory + if credFactory == nil { + credFactory = &defaultCredFactory{} + } + cred, err := credFactory.NewDefaultAzureCredential(a.options) + if err != nil { + return nil, fmt.Errorf("failed to create default azure credential: %w", err) + } + + token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: a.scopes}) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAccessToken, &token) +} diff --git a/identity/azure_default_identity_provider_test.go b/identity/azure_default_identity_provider_test.go new file mode 100644 index 0000000..67a43f8 --- /dev/null +++ b/identity/azure_default_identity_provider_test.go @@ -0,0 +1,102 @@ +package identity + +import ( + "context" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewDefaultAzureIdentityProvider(t *testing.T) { + t.Parallel() + // Create a new DefaultAzureIdentityProvider with default options + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{}) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + // Check if the provider is not nil + if provider == nil { + t.Fatal("provider should not be nil") + } + + if provider.scopes == nil { + t.Fatal("provider.scopes should not be nil") + } + + assert.Contains(t, provider.scopes, RedisScopeDefault, "provider should contain default scope") +} +func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { + t.Parallel() + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{}) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + // Request a token from the provider in incorrect environment + // should fail. + token, err := provider.RequestToken(context.Background()) + assert.Nil(t, token, "token should be nil") + assert.Error(t, err, "failed to request token") + + // use mockAzureCredential to simulate the environment + mToken := azcore.AccessToken{ + Token: testJWTToken, + } + mCreds := &mockAzureCredential{} + mCreds.On("GetToken", mock.Anything, mock.Anything).Return(mToken, nil) + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) + provider.credFactory = mCredFactory + resp, err := provider.RequestToken(context.Background()) + assert.NotNil(t, resp, "resp should not be nil") + assert.NoError(t, err, "failed to request resp") + assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") + assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access token should be equal to testJWTToken") +} + +func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { + // Create a new DefaultAzureIdentityProvider with custom scopes + scopes := []string{"https://example.com/.default"} + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{ + Scopes: scopes, + }) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + t.Run("RequestToken with custom scopes", func(t *testing.T) { + // Request a token from the provider + token, err := provider.RequestToken(context.Background()) + assert.Nil(t, token, "token should be nil") + assert.Error(t, err, "failed to request token") + + // use mockAzureCredential to simulate the environment + mToken := azcore.AccessToken{ + Token: testJWTToken, + } + mCreds := &mockAzureCredential{} + mCreds.On("GetToken", mock.Anything, policy.TokenRequestOptions{Scopes: scopes}).Return(mToken, nil) + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) + provider.credFactory = mCredFactory + resp, err := provider.RequestToken(context.Background()) + assert.NotNil(t, resp, "resp should not be nil") + assert.NoError(t, err, "failed to request resp") + assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") + assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access resp should be equal to testJWTToken") + }) + t.Run("RequestToken with error from credFactory", func(t *testing.T) { + // use mockAzureCredential to simulate the environment + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError) + provider.credFactory = mCredFactory + resp, err := provider.RequestToken(context.Background()) + assert.Nil(t, resp, "resp should be nil") + assert.Error(t, err, "failed to request resp") + }) +} diff --git a/identity/confidential_identity_provider.go b/identity/confidential_identity_provider.go new file mode 100644 index 0000000..532b0c8 --- /dev/null +++ b/identity/confidential_identity_provider.go @@ -0,0 +1,169 @@ +package identity + +import ( + "context" + "crypto" + "crypto/x509" + "fmt" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// ConfidentialIdentityProviderOptions represents the options for the confidential identity provider. +type ConfidentialIdentityProviderOptions struct { + // ClientID is the client ID used to authenticate with the identity provider. + ClientID string + + // CredentialsType is the type of credentials used to authenticate with the identity provider. + // This can be either "ClientSecret" or "ClientCertificate". + CredentialsType string + + // ClientSecret is the client secret used to authenticate with the identity provider. + ClientSecret string + + // ClientCert is the client certificate used to authenticate with the identity provider. + ClientCert []*x509.Certificate + // ClientPrivateKey is the private key used to authenticate with the identity provider. + ClientPrivateKey crypto.PrivateKey + + // Scopes is the list of scopes used to request a token from the identity provider. + Scopes []string + + // Authority is the authority used to authenticate with the identity provider. + Authority AuthorityConfiguration + + // confidentialCredFactory is a factory for creating the confidential credential. + // This is used for testing purposes, to allow mocking the credential creation. + confidentialCredFactory confidentialCredFactory +} + +// ConfidentialIdentityProvider represents a confidential identity provider. +type ConfidentialIdentityProvider struct { + // clientID is the client ID used to authenticate with the identity provider. + clientID string + + // credential is the credential used to authenticate with the identity provider. + credential confidential.Credential + + // scopes is the list of scopes used to request a token from the identity provider. + scopes []string + + // client confidential is the client used to request a token from the identity provider. + client confidentialTokenClient +} + +// confidentialCredFacotory is a factory for creating the confidential credential. +// Introduced for testing purposes. This allows mocking the credential creation, default behavior is to use the confidential.NewCredFromSecret and confidential.NewCredFromCert methods. +type confidentialCredFactory interface { + NewCredFromSecret(clientSecret string) (confidential.Credential, error) + NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) +} + +// confidentialTokenClient is an interface that defines the methods for a confidential token client. +// It is used to acquire a token using the client credentials. +// Introduced for testing purposes. This allows mocking the token client, default behavior is to use the +// client returned by confidential.New method. +type confidentialTokenClient interface { + // AcquireTokenByCredential acquires a token using the client credentials. + // It returns the token and an error if any. + AcquireTokenByCredential(ctx context.Context, scopes []string, opts ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error) +} + +type defaultConfidentialCredFactory struct{} + +func (d *defaultConfidentialCredFactory) NewCredFromSecret(clientSecret string) (confidential.Credential, error) { + return confidential.NewCredFromSecret(clientSecret) +} + +func (d *defaultConfidentialCredFactory) NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) { + return confidential.NewCredFromCert(clientCert, clientPrivateKey) +} + +// NewConfidentialIdentityProvider creates a new confidential identity provider. +// It is used to configure the identity provider when requesting a token. +// It is used to specify the client ID, tenant ID, and scopes for the identity. +// It is also used to specify the type of credentials used to authenticate with the identity provider. +// The credentials can be either a client secret or a client certificate. +// The authority is used to authenticate with the identity provider. +func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) (*ConfidentialIdentityProvider, error) { + var credential confidential.Credential + var credFactory confidentialCredFactory + var authority string + var err error + + if opts.ClientID == "" { + return nil, fmt.Errorf("client ID is required") + } + + if opts.CredentialsType != ClientSecretCredentialType && opts.CredentialsType != ClientCertificateCredentialType { + return nil, fmt.Errorf("invalid credentials type") + } + + // Get the authority from the authority configuration. + authority, err = opts.Authority.getAuthority() + if err != nil { + return nil, fmt.Errorf("failed to get authority: %w", err) + } + + credFactory = &defaultConfidentialCredFactory{} + if opts.confidentialCredFactory != nil { + credFactory = opts.confidentialCredFactory + } + + switch opts.CredentialsType { + case ClientSecretCredentialType: + // ClientSecretCredentialType is the type of credentials that uses a client secret to authenticate. + if opts.ClientSecret == "" { + return nil, fmt.Errorf("client secret is required when using client secret credentials") + } + + credential, err = credFactory.NewCredFromSecret(opts.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to create client secret credential: %w", err) + } + case ClientCertificateCredentialType: + // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. + if len(opts.ClientCert) == 0 { + return nil, fmt.Errorf("non-empty client certificate is required when using client certificate credentials") + } + if opts.ClientPrivateKey == nil { + return nil, fmt.Errorf("client private key is required when using client certificate credentials") + } + credential, err = credFactory.NewCredFromCert(opts.ClientCert, opts.ClientPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate credential: %w", err) + } + } + + client, err := confidential.New(authority, opts.ClientID, credential) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + if opts.Scopes == nil { + opts.Scopes = []string{RedisScopeDefault} + } + + return &ConfidentialIdentityProvider{ + clientID: opts.ClientID, + credential: credential, + scopes: opts.Scopes, + client: &client, + }, nil +} + +// RequestToken requests a token from the identity provider. +// It returns the identity provider response, including the auth result. +func (c *ConfidentialIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + if c.client == nil { + return nil, fmt.Errorf("client is not initialized") + } + + result, err := c.client.AcquireTokenByCredential(ctx, c.scopes) + if err != nil { + return nil, fmt.Errorf("failed to acquire token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, &result) +} diff --git a/identity/confidential_identity_provider_test.go b/identity/confidential_identity_provider_test.go new file mode 100644 index 0000000..cef1976 --- /dev/null +++ b/identity/confidential_identity_provider_test.go @@ -0,0 +1,309 @@ +package identity + +import ( + "context" + "crypto/x509" + "fmt" + "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewConfidentialIdentityProvider(t *testing.T) { + t.Run("base", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + }) + + t.Run("with client certificate", func(t *testing.T) { + t.Parallel() + credFactory := &mockConfidentialCredentialFactory{} + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{&x509.Certificate{}}, + ClientPrivateKey: "private-key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + confidentialCredFactory: credFactory, + } + credFactory.On("NewCredFromCert", opts.ClientCert, opts.ClientPrivateKey).Return(confidential.Credential{}, nil) + provider, err := NewConfidentialIdentityProvider(opts) + // confidential.New will fail since the credentials are invalid + assert.ErrorContains(t, err, "failed to create client:") + assert.Nil(t, provider) + }) + + t.Run("with failing client certificate", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{&x509.Certificate{}}, + ClientPrivateKey: "private-key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + // invalid certificate should fail + provider, err := NewConfidentialIdentityProvider(opts) + assert.ErrorContains(t, err, "failed to create client certificate credential:") + assert.Nil(t, provider) + }) + + t.Run("with invalid credentials type", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "invalid-credentials-type", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with missing client id", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + CredentialsType: "ClientSecret", + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with bad authority type", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{AuthorityType: "bad-authority-type"}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + t.Run("with missing client secret", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + Scopes: []string{"scope1", "scope2"}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with credentials from secret error", func(t *testing.T) { + t.Parallel() + credFactory := &mockConfidentialCredentialFactory{} + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + confidentialCredFactory: credFactory, + } + credFactory.On("NewCredFromSecret", "client-secret").Return(confidential.Credential{}, fmt.Errorf("error creating credential")) + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + credFactory.AssertExpectations(t) + }) + + t.Run("empty certificate", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: nil, + ClientPrivateKey: "private key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("empty private key", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{&x509.Certificate{}}, + ClientPrivateKey: nil, + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + t.Run("validate default scopes", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + if len(provider.scopes) == 0 { + t.Errorf("NewConfidentialIdentityProvider() provider.Scopes = %v, want non-empty", provider.scopes) + return + } + assert.Contains(t, provider.scopes, RedisScopeDefault) + }) +} + +func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { + t.Run("with mock client", func(t *testing.T) { + t.Parallel() + mClient := &mockConfidentialTokenClient{} + + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://test-authority.dev/test", + }, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + expiresOn := time.Now().Add(time.Hour) + provider.client = mClient + mClient.On("AcquireTokenByCredential", mock.Anything, mock.Anything). + Return(confidential.AuthResult{ + ExpiresOn: expiresOn, + }, nil) + token, err := provider.RequestToken(context.Background()) + if err != nil { + t.Errorf("RequestToken() error = %v", err) + return + } + assert.NotEmpty(t, token, "RequestToken() token should not be empty") + assert.Equal(t, token.Type(), shared.ResponseTypeAuthResult, "RequestToken() token type should be AuthResult") + assert.Equal(t, token.(shared.AuthResultIDPResponse).AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match") + }) + t.Run("with error", func(t *testing.T) { + t.Parallel() + mClient := &mockConfidentialTokenClient{} + + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://test-authority.dev/test", + }, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + provider.client = mClient + mClient.On("AcquireTokenByCredential", mock.Anything, mock.Anything). + Return(confidential.AuthResult{}, fmt.Errorf("error acquiring token")) + token, err := provider.RequestToken(context.Background()) + assert.ErrorContains(t, err, "failed to acquire token:") + assert.Empty(t, token, "RequestToken() token should be empty") + }) + t.Run("without initialization", func(t *testing.T) { + t.Parallel() + provider := &ConfidentialIdentityProvider{} + token, err := provider.RequestToken(context.Background()) + assert.ErrorContains(t, err, "client is not initialized") + assert.Empty(t, token, "RequestToken() token should be empty") + }) +} diff --git a/identity/identity_test.go b/identity/identity_test.go new file mode 100644 index 0000000..4d33a65 --- /dev/null +++ b/identity/identity_test.go @@ -0,0 +1,86 @@ +package identity + +import ( + "context" + "crypto" + "crypto/x509" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/stretchr/testify/mock" +) + +// testJWTToken is a JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1743515011, +// "exp": 1775051011, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0MzUxNTAxMSwiZXhwIjoxNzc1MDUxMDExLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.6RG721V2eFlSLsCRmo53kSRRrTZIe1UPdLZCUEvIarU" + +type mockAzureCredential struct { + mock.Mock +} + +func (m *mockAzureCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return azcore.AccessToken{}, args.Error(1) + } + return args.Get(0).(azcore.AccessToken), args.Error(1) +} + +type mockCredFactory struct { + // Mock implementation of the credFactory interface + mock.Mock +} + +func (m *mockCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) { + args := m.Called(options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(azureCredential), args.Error(1) +} + +type mockConfidentialCredentialFactory struct { + // Mock implementation of the confidentialCredFactory interface + mock.Mock +} + +func (m *mockConfidentialCredentialFactory) NewCredFromSecret(clientSecret string) (confidential.Credential, error) { + args := m.Called(clientSecret) + if args.Get(0) == nil { + return confidential.Credential{}, args.Error(1) + } + return args.Get(0).(confidential.Credential), args.Error(1) +} + +func (m *mockConfidentialCredentialFactory) NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) { + args := m.Called(clientCert, clientPrivateKey) + if args.Get(0) == nil { + return confidential.Credential{}, args.Error(1) + } + return args.Get(0).(confidential.Credential), args.Error(1) +} + +type mockConfidentialTokenClient struct { + // Mock implementation of the confidentialTokenClient interface + mock.Mock +} + +func (m *mockConfidentialTokenClient) AcquireTokenByCredential(ctx context.Context, scopes []string, options ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return confidential.AuthResult{}, args.Error(1) + } + return args.Get(0).(confidential.AuthResult), args.Error(1) +} diff --git a/identity/managed_identity_provider.go b/identity/managed_identity_provider.go new file mode 100644 index 0000000..cb5a4af --- /dev/null +++ b/identity/managed_identity_provider.go @@ -0,0 +1,124 @@ +package identity + +import ( + "context" + "errors" + "fmt" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// ManagedIdentityClient is an interface that defines the methods for a managed identity client. +// It is used to acquire a token using the managed identity. +type ManagedIdentityClient interface { + // AcquireToken acquires a token using the managed identity. + // It returns the token and an error if any. + AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) +} + +// ManagedIdentityProviderOptions represents the options for the managed identity provider. +// It is used to configure the identity provider when requesting a token. +type ManagedIdentityProviderOptions struct { + // UserAssignedClientID is the client ID of the user assigned identity. + // This is used to identify the identity when requesting a token. + UserAssignedClientID string + // ManagedIdentityType is the type of managed identity. + // This can be either SystemAssigned or UserAssigned. + ManagedIdentityType string + // Scopes is a list of scopes that the identity has access to. + // This is used to specify the permissions that the identity has when requesting a token. + Scopes []string +} + +// ManagedIdentityProvider represents a managed identity provider. +type ManagedIdentityProvider struct { + // userAssignedClientID is the client ID of the user assigned identity. + // This is used to identify the identity when requesting a token. + userAssignedClientID string + + // managedIdentityType is the type of managed identity. + // This can be either SystemAssigned or UserAssigned. + managedIdentityType string + + // scopes is a list of scopes that the identity has access to. + // This is used to specify the permissions that the identity has when requesting a token. + scopes []string + + // client is the managed identity client used to request a token. + client ManagedIdentityClient +} + +// realManagedIdentityClient is a wrapper around the real mi.Client that implements our interface +type realManagedIdentityClient struct { + client ManagedIdentityClient +} + +func (c *realManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + return c.client.AcquireToken(ctx, resource, opts...) +} + +// NewManagedIdentityProvider creates a new managed identity provider for Azure with managed identity. +// It is used to configure the identity provider when requesting a token. +func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedIdentityProvider, error) { + var client ManagedIdentityClient + + if opts.ManagedIdentityType != SystemAssignedIdentity && opts.ManagedIdentityType != UserAssignedIdentity { + return nil, errors.New("invalid managed identity type") + } + + switch opts.ManagedIdentityType { + case SystemAssignedIdentity: + // SystemAssignedIdentity is the type of identity that is automatically managed by Azure. + // This type of identity is automatically created and managed by Azure. + // It is used to authenticate the identity when requesting a token. + miClient, err := mi.New(mi.SystemAssigned()) + if err != nil { + return nil, fmt.Errorf("couldn't create managed identity client: %w", err) + } + client = &realManagedIdentityClient{client: miClient} + case UserAssignedIdentity: + // UserAssignedIdentity is required to be specified when using a user assigned identity. + if opts.UserAssignedClientID == "" { + return nil, errors.New("user assigned client ID is required when using user assigned identity") + } + // UserAssignedIdentity is the type of identity that is managed by the user. + miClient, err := mi.New(mi.UserAssignedClientID(opts.UserAssignedClientID)) + if err != nil { + return nil, fmt.Errorf("couldn't create managed identity client: %w", err) + } + client = &realManagedIdentityClient{client: miClient} + } + + return &ManagedIdentityProvider{ + userAssignedClientID: opts.UserAssignedClientID, + managedIdentityType: opts.ManagedIdentityType, + scopes: opts.Scopes, + client: client, + }, nil +} + +// RequestToken requests a token from the managed identity provider. +// It returns IdentityProviderResponse, which contains the Acc and the expiration time. +func (m *ManagedIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + if m.client == nil { + return nil, errors.New("managed identity client is not initialized") + } + + // default resource is RedisResource == "https://redis.azure.com" + // if no scopes are provided, use the default resource + // if scopes are provided, use the first scope as the resource + resource := RedisResource + if len(m.scopes) > 0 { + resource = m.scopes[0] + } + // acquire token using the managed identity client + // the resource is the URL of the resource that the identity has access to + authResult, err := m.client.AcquireToken(ctx, resource) + if err != nil { + return nil, fmt.Errorf("couldn't acquire token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, &authResult) +} diff --git a/identity/managed_identity_provider_test.go b/identity/managed_identity_provider_test.go new file mode 100644 index 0000000..80dd661 --- /dev/null +++ b/identity/managed_identity_provider_test.go @@ -0,0 +1,305 @@ +package identity + +import ( + "context" + "errors" + "testing" + "time" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockManagedIdentityClient is a mock implementation of the managed identity client +type MockManagedIdentityClient struct { + mock.Mock +} + +func (m *MockManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + args := m.Called(ctx, resource) + return args.Get(0).(public.AuthResult), args.Error(1) +} + +func TestNewManagedIdentityProvider(t *testing.T) { + tests := []struct { + name string + opts ManagedIdentityProviderOptions + expectedError string + }{ + { + name: "System assigned identity", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: SystemAssignedIdentity, + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "", + }, + { + name: "User assigned identity with client ID", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: UserAssignedIdentity, + UserAssignedClientID: "test-client-id", + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "", + }, + { + name: "User assigned identity without client ID", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: UserAssignedIdentity, + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "user assigned client ID is required when using user assigned identity", + }, + { + name: "Invalid identity type", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "invalid managed identity type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewManagedIdentityProvider(tt.opts) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, tt.opts.ManagedIdentityType, provider.managedIdentityType) + assert.Equal(t, tt.opts.UserAssignedClientID, provider.userAssignedClientID) + assert.Equal(t, tt.opts.Scopes, provider.scopes) + assert.NotNil(t, provider.client) + } + }) + } +} + +func TestRequestToken(t *testing.T) { + tests := []struct { + name string + provider *ManagedIdentityProvider + expectedError string + }{ + { + name: "Success with default resource", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: new(MockManagedIdentityClient), + }, + expectedError: "", + }, + { + name: "Success with custom resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"custom-resource"}, + client: new(MockManagedIdentityClient), + }, + expectedError: "", + }, + { + name: "Error when client is nil", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: nil, + }, + expectedError: "managed identity client is not initialized", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the mock expectations if we have a mock client + if tt.provider.client != nil { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + expectedResource := RedisResource + if len(tt.provider.scopes) > 0 { + expectedResource = tt.provider.scopes[0] + } + + if tt.expectedError == "" { + mockClient.On("AcquireToken", mock.Anything, expectedResource). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + } else { + mockClient.On("AcquireToken", mock.Anything, expectedResource). + Return(public.AuthResult{}, errors.New(tt.expectedError)) + } + } + + response, err := tt.provider.RequestToken(context.Background()) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + } + + // Verify mock expectations + if tt.provider.client != nil { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + mockClient.AssertExpectations(t) + } + }) + } +} + +func TestRequestToken_ErrorCases(t *testing.T) { + tests := []struct { + name string + provider *ManagedIdentityProvider + setupMock func(*MockManagedIdentityClient) + expectedError string + }{ + { + name: "AcquireToken fails", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource). + Return(public.AuthResult{}, errors.New("failed to acquire token")) + }, + expectedError: "couldn't acquire token: failed to acquire token", + }, + { + name: "AcquireToken fails with custom resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"custom-resource"}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "custom-resource"). + Return(public.AuthResult{}, errors.New("failed to acquire token")) + }, + expectedError: "couldn't acquire token: failed to acquire token", + }, + { + name: "AcquireToken fails with invalid resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"invalid-resource"}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "invalid-resource"). + Return(public.AuthResult{}, errors.New("invalid resource")) + }, + expectedError: "couldn't acquire token: invalid resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + tt.setupMock(mockClient) + + response, err := tt.provider.RequestToken(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, response) + mockClient.AssertExpectations(t) + }) + } +} + +// MockMIClient is a mock implementation of the mi.Client interface +type MockMIClient struct { + mock.Mock +} + +func (m *MockMIClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + args := m.Called(ctx, resource) + return args.Get(0).(public.AuthResult), args.Error(1) +} + +func (m *MockMIClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func TestRealManagedIdentityClient(t *testing.T) { + // Create a mock managed identity client + mockMIClient := new(MockManagedIdentityClient) + client := &realManagedIdentityClient{client: mockMIClient} + + tests := []struct { + name string + resource string + setupMock func(*MockManagedIdentityClient) + expectedError string + }{ + { + name: "Success with default resource", + resource: RedisResource, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + }, + }, + { + name: "Success with custom resource", + resource: "custom-resource", + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "custom-resource", mock.Anything). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + }, + }, + { + name: "Error from underlying client", + resource: RedisResource, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything). + Return(public.AuthResult{}, errors.New("underlying client error")) + }, + expectedError: "underlying client error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the mock for each test + mockMIClient.ExpectedCalls = nil + mockMIClient.Calls = nil + + // Set up the mock + tt.setupMock(mockMIClient) + + // Call AcquireToken with empty options slice to match mock setup + result, err := client.AcquireToken(context.Background(), tt.resource, []mi.AcquireTokenOption{}...) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Equal(t, public.AuthResult{}, result) + } else { + assert.NoError(t, err) + assert.NotEqual(t, public.AuthResult{}, result) + assert.Equal(t, "test-token", result.AccessToken) + } + + // Verify mock expectations + mockMIClient.AssertExpectations(t) + }) + } +} diff --git a/identity/providers.go b/identity/providers.go new file mode 100644 index 0000000..24126f3 --- /dev/null +++ b/identity/providers.go @@ -0,0 +1,25 @@ +package identity + +// CredentialsProviderOptions is a struct that holds the options for the credentials provider. + +const ( + // SystemAssignedIdentity is the type of identity that is automatically managed by Azure. + SystemAssignedIdentity = "SystemAssigned" + // UserAssignedIdentity is the type of identity that is managed by the user. + UserAssignedIdentity = "UserAssigned" + + // ClientSecretCredentialType is the type of credentials that uses a client secret to authenticate. + ClientSecretCredentialType = "ClientSecret" + // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. + ClientCertificateCredentialType = "ClientCertificate" + + // RedisScopeDefault is the default scope for Redis. + // This is used to specify the scope that the identity has access to when requesting a token. + // The scope is typically the URL of the resource that the identity has access to. + RedisScopeDefault = "https://redis.azure.com/.default" + + // RedisResource is the default resource for Redis. + // This is used to specify the resource that the identity has access to when requesting a token. + // The resource is typically the URL of the resource that the identity has access to. + RedisResource = "https://redis.azure.com" +) diff --git a/identity/providers_test.go b/identity/providers_test.go new file mode 100644 index 0000000..0712d0f --- /dev/null +++ b/identity/providers_test.go @@ -0,0 +1,52 @@ +package identity + +import ( + "testing" +) + +func TestConstants(t *testing.T) { + tests := []struct { + name string + got string + expected string + }{ + { + name: "SystemAssignedIdentity", + got: SystemAssignedIdentity, + expected: "SystemAssigned", + }, + { + name: "UserAssignedIdentity", + got: UserAssignedIdentity, + expected: "UserAssigned", + }, + { + name: "ClientSecretCredentialType", + got: ClientSecretCredentialType, + expected: "ClientSecret", + }, + { + name: "ClientCertificateCredentialType", + got: ClientCertificateCredentialType, + expected: "ClientCertificate", + }, + { + name: "RedisScopeDefault", + got: RedisScopeDefault, + expected: "https://redis.azure.com/.default", + }, + { + name: "RedisResource", + got: RedisResource, + expected: "https://redis.azure.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.expected) + } + }) + } +} diff --git a/install-git-hook.sh b/install-git-hook.sh new file mode 100755 index 0000000..5598436 --- /dev/null +++ b/install-git-hook.sh @@ -0,0 +1,2 @@ +chmod ug+x ./.githooks/* +git config core.hooksPath ./.githooks diff --git a/internal/idp_response.go b/internal/idp_response.go new file mode 100644 index 0000000..8ae0cdc --- /dev/null +++ b/internal/idp_response.go @@ -0,0 +1,95 @@ +package internal + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" +) + +// IDPResp represents a response from an Identity Provider (IDP) +// It can contain either an AuthResult, AccessToken, or a raw token string +type IDPResp struct { + // resultType indicates which type of response this is + resultType string + authResultVal *public.AuthResult + accessTokenVal *azcore.AccessToken + rawTokenVal string +} + +// NewIDPResp creates a new IDPResp with the given values +// It validates the input and ensures the response type matches the provided value +func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) { + if result == nil { + return nil, fmt.Errorf("result cannot be nil") + } + + r := &IDPResp{resultType: resultType} + + switch resultType { + case "AuthResult": + switch v := result.(type) { + case *public.AuthResult: + r.authResultVal = v + case public.AuthResult: + r.authResultVal = &v + default: + return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result) + } + case "AccessToken": + switch v := result.(type) { + case *azcore.AccessToken: + r.accessTokenVal = v + r.rawTokenVal = v.Token + case azcore.AccessToken: + r.accessTokenVal = &v + r.rawTokenVal = v.Token + default: + return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result) + } + case "RawToken": + switch v := result.(type) { + case string: + r.rawTokenVal = v + case *string: + if v == nil { + return nil, fmt.Errorf("raw token cannot be nil") + } + r.rawTokenVal = *v + default: + return nil, fmt.Errorf("invalid raw token type: expected string or *string, got %T", result) + } + default: + return nil, fmt.Errorf("unsupported identity provider response type: %s", resultType) + } + + return r, nil +} + +// Type returns the type of response this IDPResp represents +func (a *IDPResp) Type() string { + return a.resultType +} + +// AuthResult returns the AuthResult if present, or an empty AuthResult if not set +// Use HasAuthResult() to check if the value is actually set +func (a *IDPResp) AuthResult() public.AuthResult { + if a.authResultVal == nil { + return public.AuthResult{} + } + return *a.authResultVal +} + +// AccessToken returns the AccessToken if present, or an empty AccessToken if not set +// Use HasAccessToken() to check if the value is actually set +func (a *IDPResp) AccessToken() azcore.AccessToken { + if a.accessTokenVal == nil { + return azcore.AccessToken{} + } + return *a.accessTokenVal +} + +// RawToken returns the raw token string +func (a *IDPResp) RawToken() string { + return a.rawTokenVal +} diff --git a/internal/idp_response_test.go b/internal/idp_response_test.go new file mode 100644 index 0000000..7b98226 --- /dev/null +++ b/internal/idp_response_test.go @@ -0,0 +1,354 @@ +package internal + +import ( + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/assert" +) + +func TestIDPResp_Type(t *testing.T) { + tests := []struct { + name string + resultType string + want string + }{ + { + name: "AuthResult type", + resultType: "AuthResult", + want: "AuthResult", + }, + { + name: "Empty type", + resultType: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + resultType: tt.resultType, + } + if got := resp.Type(); got != tt.want { + t.Errorf("IDPResp.Type() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIDPResp_AuthResult(t *testing.T) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + + tests := []struct { + name string + authResult *public.AuthResult + wantToken string + wantExpiresOn time.Time + }{ + { + name: "With AuthResult", + authResult: authResult, + wantToken: "test-token", + wantExpiresOn: now, + }, + { + name: "Nil AuthResult", + authResult: nil, + wantToken: "", + wantExpiresOn: time.Time{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + authResultVal: tt.authResult, + } + got := resp.AuthResult() + if got.AccessToken != tt.wantToken { + t.Errorf("IDPResp.AuthResult().AccessToken = %v, want %v", got.AccessToken, tt.wantToken) + } + if !got.ExpiresOn.Equal(tt.wantExpiresOn) { + t.Errorf("IDPResp.AuthResult().ExpiresOn = %v, want %v", got.ExpiresOn, tt.wantExpiresOn) + } + }) + } +} + +func TestIDPResp_AccessToken(t *testing.T) { + now := time.Now() + accessToken := &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: now, + } + + tests := []struct { + name string + accessToken *azcore.AccessToken + wantToken string + wantExpiresOn time.Time + }{ + { + name: "With AccessToken", + accessToken: accessToken, + wantToken: "test-token", + wantExpiresOn: now, + }, + { + name: "Nil AccessToken", + accessToken: nil, + wantToken: "", + wantExpiresOn: time.Time{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + accessTokenVal: tt.accessToken, + } + got := resp.AccessToken() + if got.Token != tt.wantToken { + t.Errorf("IDPResp.AccessToken().Token = %v, want %v", got.Token, tt.wantToken) + } + if !got.ExpiresOn.Equal(tt.wantExpiresOn) { + t.Errorf("IDPResp.AccessToken().ExpiresOn = %v, want %v", got.ExpiresOn, tt.wantExpiresOn) + } + }) + } +} + +func TestIDPResp_RawToken(t *testing.T) { + tests := []struct { + name string + rawToken string + want string + }{ + { + name: "With RawToken", + rawToken: "test-raw-token", + want: "test-raw-token", + }, + { + name: "Empty RawToken", + rawToken: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + rawTokenVal: tt.rawToken, + } + if got := resp.RawToken(); got != tt.want { + t.Errorf("IDPResp.RawToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewIDPResp(t *testing.T) { + tests := []struct { + name string + resultType string + result interface{} + wantErr bool + checkResult func(t *testing.T, resp *IDPResp) + }{ + { + name: "valid AuthResult pointer", + resultType: "AuthResult", + result: &public.AuthResult{ + AccessToken: "test-token", + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + }, + }, + { + name: "valid AuthResult value", + resultType: "AuthResult", + result: public.AuthResult{ + AccessToken: "test-token", + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + }, + }, + { + name: "valid AccessToken pointer", + resultType: "AccessToken", + result: &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: time.Now(), + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.AccessToken().Token) + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "valid AccessToken value", + resultType: "AccessToken", + result: azcore.AccessToken{ + Token: "test-token", + ExpiresOn: time.Now(), + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.AccessToken().Token) + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "valid RawToken string", + resultType: "RawToken", + result: "test-token", + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "valid RawToken string pointer", + resultType: "RawToken", + result: stringPtr("test-token"), + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "nil result", + resultType: "AuthResult", + result: nil, + wantErr: true, + }, + { + name: "nil RawToken pointer", + resultType: "RawToken", + result: (*string)(nil), + wantErr: true, + }, + { + name: "invalid AuthResult type", + resultType: "AuthResult", + result: "not-an-auth-result", + wantErr: true, + }, + { + name: "invalid AccessToken type", + resultType: "AccessToken", + result: "not-an-access-token", + wantErr: true, + }, + { + name: "invalid RawToken type", + resultType: "RawToken", + result: 123, + wantErr: true, + }, + { + name: "unsupported result type", + resultType: "InvalidType", + result: "test", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewIDPResp(tt.resultType, tt.result) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, tt.resultType, got.Type()) + + if tt.checkResult != nil { + tt.checkResult(t, got) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} + +func BenchmarkIDPResp_Type(b *testing.B) { + resp := &IDPResp{ + resultType: "AuthResult", + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.Type() + } +} + +func BenchmarkIDPResp_AuthResult(b *testing.B) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + resp := &IDPResp{ + authResultVal: authResult, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.AuthResult() + } +} + +func BenchmarkIDPResp_AccessToken(b *testing.B) { + now := time.Now() + accessToken := &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: now, + } + resp := &IDPResp{ + accessTokenVal: accessToken, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.AccessToken() + } +} + +func BenchmarkIDPResp_RawToken(b *testing.B) { + resp := &IDPResp{ + rawTokenVal: "test-raw-token", + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.RawToken() + } +} + +func BenchmarkNewIDPResp(b *testing.B) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = NewIDPResp("AuthResult", authResult) + } +} diff --git a/internal/utils.go b/internal/utils.go new file mode 100644 index 0000000..ba82f1c --- /dev/null +++ b/internal/utils.go @@ -0,0 +1,12 @@ +package internal + +// IsClosed checks if a channel is closed. +func IsClosed(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + } + + return false +} diff --git a/internal/utils_test.go b/internal/utils_test.go new file mode 100644 index 0000000..e80ccb0 --- /dev/null +++ b/internal/utils_test.go @@ -0,0 +1,58 @@ +package internal + +import "testing" + +func TestIsClosedWithNilChannel(t *testing.T) { + t.Parallel() + var ch chan struct{} + if IsClosed(ch) { + t.Error("expected nil channel to be open") + } +} + +func TestIsClosedWithEmptyChannel(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) + if IsClosed(ch) { + t.Error("expected empty channel to be open") + } + + close(ch) + if !IsClosed(ch) { + t.Error("expected empty channel to be closed") + } +} + +func TestIsClosedWithClosedChannel(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) + close(ch) + if !IsClosed(ch) { + t.Error("expected closed channel to be closed") + } +} + +func BenchmarkIsClosedWithNilChannel(b *testing.B) { + var ch chan struct{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} + +func BenchmarkIsClosedWithEmptyChannel(b *testing.B) { + ch := make(chan struct{}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} + +func BenchmarkIsClosedWithClosedChannel(b *testing.B) { + ch := make(chan struct{}) + close(ch) + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} diff --git a/manager/defaults.go b/manager/defaults.go new file mode 100644 index 0000000..992b878 --- /dev/null +++ b/manager/defaults.go @@ -0,0 +1,180 @@ +package manager + +import ( + "errors" + "fmt" + "net" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" +) + +const ( + DefaultExpirationRefreshRatio = 0.7 + DefaultRetryOptionsMaxAttempts = 3 + DefaultRetryOptionsBackoffMultiplier = 2.0 + DefaultRetryOptionsInitialDelay = 1000 * time.Millisecond + DefaultRetryOptionsMaxDelay = 10000 * time.Millisecond +) + +// defaultIsRetryable is a function that checks if the error is retriable. +// It takes an error as an argument and returns a boolean value. +// The function checks if the error is a net.Error and if it is a timeout or temporary error. +// Returns true for nil errors. +var defaultIsRetryable = func(err error) bool { + if err == nil { + return true + } + + var netErr net.Error + if errors.As(err, &netErr) { + // Check for timeout first as it's more specific + if netErr.Timeout() { + return true + } + // For temporary errors, we'll use a more modern approach + var tempErr interface{ Temporary() bool } + if errors.As(err, &tempErr) { + return tempErr.Temporary() + } + } + + return errors.Is(err, os.ErrDeadlineExceeded) +} + +// defaultRetryOptionsOr returns the default retry options if the provided options are not set. +// It sets the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. +// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. +// The values can be overridden by the user. +func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions { + if retryOptions.IsRetryable == nil { + retryOptions.IsRetryable = defaultIsRetryable + } + + if retryOptions.MaxAttempts <= 0 { + retryOptions.MaxAttempts = DefaultRetryOptionsMaxAttempts + } + if retryOptions.InitialDelay == 0 { + retryOptions.InitialDelay = DefaultRetryOptionsInitialDelay + } + if retryOptions.BackoffMultiplier == 0 { + retryOptions.BackoffMultiplier = DefaultRetryOptionsBackoffMultiplier + } + if retryOptions.MaxDelay == 0 { + retryOptions.MaxDelay = DefaultRetryOptionsMaxDelay + } + return retryOptions +} + +// defaultIdentityProviderResponseParserOr returns the default token parser if the provided token parser is not set. +// It sets the default token parser to the defaultIdentityProviderResponseParser function. +// The default token parser is used to parse the raw token and return a Token object. +func defaultIdentityProviderResponseParserOr(idpResponseParser shared.IdentityProviderResponseParser) shared.IdentityProviderResponseParser { + if idpResponseParser == nil { + return entraidIdentityProviderResponseParser + } + return idpResponseParser +} + +func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptions { + options.RetryOptions = defaultRetryOptionsOr(options.RetryOptions) + options.IdentityProviderResponseParser = defaultIdentityProviderResponseParserOr(options.IdentityProviderResponseParser) + if options.ExpirationRefreshRatio == 0 { + options.ExpirationRefreshRatio = DefaultExpirationRefreshRatio + } + return options +} + +type defaultIdentityProviderResponseParser struct{} + +// ParseResponse parses the response from the identity provider and extracts the token. +// It takes an IdentityProviderResponse as an argument and returns a Token and an error if any. +// The IdentityProviderResponse contains the raw token and the expiration time. +func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) { + if response == nil { + return nil, fmt.Errorf("identity provider response cannot be nil") + } + + var username, password, rawToken string + var expiresOn time.Time + now := time.Now().UTC() + + switch response.Type() { + case shared.ResponseTypeAuthResult: + authResult := response.(shared.AuthResultIDPResponse).AuthResult() + if authResult.ExpiresOn.IsZero() { + return nil, fmt.Errorf("auth result expiration time is not set") + } + if authResult.IDToken.Oid == "" { + return nil, fmt.Errorf("auth result OID is empty") + } + rawToken = authResult.IDToken.RawToken + username = authResult.IDToken.Oid + password = rawToken + expiresOn = authResult.ExpiresOn.UTC() + + case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: + tokenStr := response.(shared.RawTokenIDPResponse).RawToken() + + if response.Type() == shared.ResponseTypeAccessToken { + accessToken := response.(shared.AccessTokenIDPResponse).AccessToken() + if accessToken.Token == "" { + return nil, fmt.Errorf("access token value is empty") + } + tokenStr = accessToken.Token + expiresOn = accessToken.ExpiresOn.UTC() + } + + if tokenStr == "" { + return nil, fmt.Errorf("raw token is empty") + } + + claims := struct { + jwt.RegisteredClaims + Oid string `json:"oid,omitempty"` + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) + } + + if claims.Oid == "" { + return nil, fmt.Errorf("JWT token does not contain OID claim") + } + + rawToken = tokenStr + username = claims.Oid + password = rawToken + + if expiresOn.IsZero() && claims.ExpiresAt != nil { + expiresOn = claims.ExpiresAt.UTC() + } + + default: + return nil, fmt.Errorf("unsupported response type: %s", response.Type()) + } + + if expiresOn.IsZero() { + return nil, fmt.Errorf("token expiration time is not set") + } + + if expiresOn.Before(now) { + return nil, fmt.Errorf("token has expired at %s (current time: %s)", expiresOn, now) + } + + // Create the token with consistent time reference + return token.New( + username, + password, + rawToken, + expiresOn, + now, + int64(time.Until(expiresOn).Seconds()), + ), nil +} diff --git a/manager/errors.go b/manager/errors.go new file mode 100644 index 0000000..dac0e9a --- /dev/null +++ b/manager/errors.go @@ -0,0 +1,9 @@ +package manager + +import "fmt" + +// ErrTokenManagerAlreadyStopped is returned when the token manager is already stopped. +var ErrTokenManagerAlreadyStopped = fmt.Errorf("token manager already stopped") + +// ErrTokenManagerAlreadyStarted is returned when the token manager is already started. +var ErrTokenManagerAlreadyStarted = fmt.Errorf("token manager already started") diff --git a/manager/manager_test.go b/manager/manager_test.go new file mode 100644 index 0000000..55c85d6 --- /dev/null +++ b/manager/manager_test.go @@ -0,0 +1,181 @@ +package manager + +import ( + "context" + "net" + "os" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/mock" +) + +// testJWTToken is a JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1743515011, +// "exp": 1775051011, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0MzUxNTAxMSwiZXhwIjoxNzc1MDUxMDExLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.6RG721V2eFlSLsCRmo53kSRRrTZIe1UPdLZCUEvIarU" + +// testJWTExpiredToken is an expired JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1617795148, +// "exp": 1617795148, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTExpiredToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTYxNzc5NTE0OCwiZXhwIjoxNjE3Nzk1MTQ4LCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.IbGPhHRiPYcpUDrhAPf4h3gH1XXBOu560NYT59rUMzc" + +// testJWTWithZeroExpiryToken is a JWT token with zero expiry for testing +// +// { +// "iss": "test jwt", +// "iat": 1744025944, +// "exp": null, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTWithZeroExpiryToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0NDAyNTk0NCwiZXhwIjpudWxsLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.bLSANIzawE5Y6rgspvvUaRhkBq6Y4E0ggjXlmHRn8ew" + +var testTokenValid = token.New( + "test", + "password", + "test", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), +) + +type mockIdentityProviderResponseParser struct { + // Mock implementation of the IdentityProviderResponseParser interface + mock.Mock +} + +func (m *mockIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) { + args := m.Called(response) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*token.Token), args.Error(1) +} + +type mockIdentityProvider struct { + // Mock implementation of the mockIdentityProvider interface + // Add any necessary fields or methods for the mock identity provider here + mock.Mock +} + +func (m *mockIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(shared.IdentityProviderResponse), args.Error(1) +} + +func newMockError(retriable bool) error { + if retriable { + return &mockError{ + isTimeout: true, + isTemporary: true, + error: os.ErrDeadlineExceeded, + } + } else { + return &mockError{ + isTimeout: false, + isTemporary: false, + error: os.ErrInvalid, + } + } +} + +type mockError struct { + // Mock implementation of the network error + error + isTimeout bool + isTemporary bool +} + +func (m *mockError) Error() string { + return "this is mock error" +} + +func (m *mockError) Timeout() bool { + return m.isTimeout +} +func (m *mockError) Temporary() bool { + return m.isTemporary +} +func (m *mockError) Unwrap() error { + return m.error +} + +func (m *mockError) Is(err error) bool { + return m.error == err +} + +var _ net.Error = (*mockError)(nil) + +type mockTokenListener struct { + // Mock implementation of the TokenManagerListener interface + mock.Mock + Id int32 +} + +func (m *mockTokenListener) OnNext(token *token.Token) { + _ = m.Called(token) +} + +func (m *mockTokenListener) OnError(err error) { + _ = m.Called(err) +} + +type authResult struct { + // ResultType is the type of the response (AuthResult, AccessToken, or RawToken) + ResultType string + // AuthResultVal is the auth result value + AuthResultVal *public.AuthResult + // AccessTokenVal is the access token value + AccessTokenVal *azcore.AccessToken + // RawTokenVal is the raw token value + RawTokenVal string +} + +func (a *authResult) Type() string { + return a.ResultType +} + +func (a *authResult) AuthResult() public.AuthResult { + if a.AuthResultVal == nil { + return public.AuthResult{} + } + return *a.AuthResultVal +} + +func (a *authResult) AccessToken() azcore.AccessToken { + if a.AccessTokenVal == nil { + return azcore.AccessToken{} + } + return *a.AccessTokenVal +} + +func (a *authResult) RawToken() string { + return a.RawTokenVal +} diff --git a/manager/token_manager.go b/manager/token_manager.go new file mode 100644 index 0000000..4dc2c33 --- /dev/null +++ b/manager/token_manager.go @@ -0,0 +1,378 @@ +package manager + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/redis-developer/go-redis-entraid/internal" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" +) + +// TokenManagerOptions is a struct that contains the options for the TokenManager. +type TokenManagerOptions struct { + // ExpirationRefreshRatio is the ratio of the token expiration time to refresh the token. + // It is used to determine when to refresh the token. + // The value should be between 0 and 1. + // For example, if the expiration time is 1 hour and the ratio is 0.75, + // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + // + // default: 0.7 + ExpirationRefreshRatio float64 + // LowerRefreshBound is the lower bound for the refresh time + // Represents the minimum time before token expiration to trigger a refresh. + // This value sets a fixed lower bound for when a token refresh should occur, regardless + // of the token's total lifetime. + // + // default: 0 (no lower bound, refresh based on ExpirationRefreshRatio) + LowerRefreshBound time.Duration + + // IdentityProviderResponseParser is an optional object that implements the IdentityProviderResponseParser interface. + // It is used to parse the response from the identity provider and extract the token. + // If not provided, the default implementation will be used. + // The objects ParseResponse method will be called to parse the response and return the token. + // + // required: false + // default: defaultIdentityProviderResponseParser + IdentityProviderResponseParser shared.IdentityProviderResponseParser + // RetryOptions is a struct that contains the options for retrying the token request. + // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. + // + // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. + RetryOptions RetryOptions + + // RequestTimeout is the timeout for the request to the identity provider. + RequestTimeout time.Duration +} + +// RetryOptions is a struct that contains the options for retrying the token request. +type RetryOptions struct { + // IsRetryable is a function that checks if the error is retriable. + // It takes an error as an argument and returns a boolean value. + // + // default: defaultRetryableFunc + IsRetryable func(err error) bool + // MaxAttempts is the maximum number of attempts to retry the token request. + // + // default: 3 + MaxAttempts int + // InitialDelay is the initial delay before retrying the token request. + // + // default: 1 second + InitialDelay time.Duration + // MaxDelay is the maximum delay between retry attempts. + // + // default: 10 seconds + MaxDelay time.Duration + // BackoffMultiplier is the multiplier for the backoff delay. + // default: 2.0 + BackoffMultiplier float64 +} + +// TokenManager is an interface that defines the methods for managing tokens. +// It provides methods to get a token and start the token manager. +// The TokenManager is responsible for obtaining and refreshing the token. +// It is typically used in conjunction with an IdentityProvider to obtain the token. +type TokenManager interface { + // GetToken returns the token for authentication. + // It takes a boolean value forceRefresh as an argument. + GetToken(forceRefresh bool) (*token.Token, error) + // Start starts the token manager and returns a channel that will receive updates. + Start(listener TokenListener) (StopFunc, error) + // Stop stops the token manager and releases any resources. + Stop() error +} + +// StopFunc is a function that stops the token manager. +type StopFunc func() error + +// TokenListener is an interface that contains the methods for receiving updates from the token manager. +// The token manager will call the listener's OnTokenNext method with the updated token. +// If an error occurs, the token manager will call the listener's OnTokenError method with the error. +type TokenListener interface { + // OnNext is called when the token is updated. + OnNext(t *token.Token) + // OnError is called when an error occurs. + OnError(err error) +} + +// entraidIdentityProviderResponseParser is the default implementation of the IdentityProviderResponseParser interface. +var entraidIdentityProviderResponseParser shared.IdentityProviderResponseParser = &defaultIdentityProviderResponseParser{} + +// NewTokenManager creates a new TokenManager. +// It takes an IdentityProvider and TokenManagerOptions as arguments and returns a TokenManager interface. +// The IdentityProvider is used to obtain the token, and the TokenManagerOptions contains options for the TokenManager. +// The TokenManager is responsible for managing the token and refreshing it when necessary. +func NewTokenManager(idp shared.IdentityProvider, options TokenManagerOptions) (TokenManager, error) { + if options.ExpirationRefreshRatio < 0 || options.ExpirationRefreshRatio > 1 { + return nil, fmt.Errorf("expiration refresh ratio must be between 0 and 1") + } + options = defaultTokenManagerOptionsOr(options) + + if idp == nil { + return nil, fmt.Errorf("identity provider is required") + } + + ctx, ctxCancel := context.WithCancel(context.Background()) + return &entraidTokenManager{ + idp: idp, + token: nil, + closedChan: nil, + ctx: ctx, + ctxCancel: ctxCancel, + expirationRefreshRatio: options.ExpirationRefreshRatio, + lowerBoundDuration: options.LowerRefreshBound, + identityProviderResponseParser: options.IdentityProviderResponseParser, + retryOptions: options.RetryOptions, + requestTimeout: options.RequestTimeout, + }, nil +} + +// entraidTokenManager is a struct that implements the TokenManager interface. +type entraidTokenManager struct { + // idp is the identity provider used to obtain the token. + idp shared.IdentityProvider + + // token is the authentication token for the user which should be kept in memory if valid. + token *token.Token + + // tokenRWLock is a read-write lock used to protect the token from concurrent access. + tokenRWLock sync.RWMutex + + // identityProviderResponseParser is the parser used to parse the response from the identity provider. + // It`s ParseResponse method will be called to parse the response and return the token. + identityProviderResponseParser shared.IdentityProviderResponseParser + + // retryOptions is a struct that contains the options for retrying the token request. + // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. + // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. + // The values can be overridden by the user. + retryOptions RetryOptions + + // listener is the single listener for the token manager. + // It is used to receive updates from the token manager. + // The token manager will call the listener's OnNext method with the updated token. + // If an error occurs, the token manager will call the listener's OnError method with the error. + // if listener is set, Start will fail + listener TokenListener + + // lock locks the listener to prevent concurrent access. + lock sync.Mutex + + // expirationRefreshRatio is the ratio of the token expiration time to refresh the token. + // It is used to determine when to refresh the token. + // The value should be between 0 and 1. + // For example, if the expiration time is 1 hour and the ratio is 0.75, + // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + expirationRefreshRatio float64 + + // lowerBoundDuration is the lower bound for the refresh time in time.Duration. + lowerBoundDuration time.Duration + + // closedChan is a channel that is closedChan when the token manager is closedChan. + // It is used to signal the token manager to stop requesting tokens. + closedChan chan struct{} + + // context is the context used to request the token from the identity provider. + ctx context.Context + + // ctxCancel is the cancel function for the context. + ctxCancel context.CancelFunc + + // requestTimeout is the timeout for the request to the identity provider. + requestTimeout time.Duration +} + +func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + e.tokenRWLock.RLock() + // check if the token is nil and if it is not expired + + if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { + t := e.token + e.tokenRWLock.RUnlock() + return t, nil + } + e.tokenRWLock.RUnlock() + + // start the context early, + // since at heavy concurrent load + // locks may take some time to acquire + ctx, ctxCancel := context.WithTimeout(e.ctx, e.requestTimeout) + defer ctxCancel() + + // Upgrade to write lock for token update + e.tokenRWLock.Lock() + defer e.tokenRWLock.Unlock() + + // Double-check pattern to avoid unnecessary token refresh + if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { + return e.token, nil + } + + // Request a new token from the identity provider + idpResult, err := e.idp.RequestToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to request token from idp: %w", err) + } + + t, err := e.identityProviderResponseParser.ParseResponse(idpResult) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + if t == nil { + return nil, fmt.Errorf("failed to get token: token is nil") + } + + // Store the token + e.token = t + // Return the token - no need to copy since it's immutable + return t, nil +} + +// Start starts the token manager and returns cancelFunc to stop the token manager. +// It takes a TokenListener as an argument, which is used to receive updates. +// The token manager will call the listener's OnNext method with the updated token. +// If an error occurs, the token manager will call the listener's OnError method with the error. +// +// Note: The initial token is delivered synchronously. +// The TokenListener will receive the token immediately, before the token manager goroutine starts. +func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { + e.lock.Lock() + defer e.lock.Unlock() + if e.listener != nil { + return nil, ErrTokenManagerAlreadyStarted + } + + if e.closedChan != nil && !internal.IsClosed(e.closedChan) { + // there is a hanging goroutine that is waiting for the closedChan to be closed + // if the closedChan is not nil and not closed, close it + close(e.closedChan) + } + + t, err := e.GetToken(true) + if err != nil { + go listener.OnError(err) + return nil, fmt.Errorf("failed to start token manager: %w", err) + } + + // Deliver initial token synchronously + listener.OnNext(t) + + e.closedChan = make(chan struct{}) + e.listener = listener + + go func(listener TokenListener, closed <-chan struct{}) { + maxDelay := e.retryOptions.MaxDelay + initialDelay := e.retryOptions.InitialDelay + + for { + timeToRenewal := e.durationToRenewal() + select { + case <-closed: + return + case <-time.After(timeToRenewal): + if timeToRenewal == 0 { + // Token was requested immediately, guard against infinite loop + select { + case <-closed: + return + case <-time.After(initialDelay): + // continue to attempt + } + } + + // Token is about to expire, refresh it + delay := initialDelay + for i := 0; i < e.retryOptions.MaxAttempts; i++ { + t, err := e.GetToken(true) + if err == nil { + listener.OnNext(t) + break + } + + // check if err is retriable + if e.retryOptions.IsRetryable(err) { + if i == e.retryOptions.MaxAttempts-1 { + // last attempt, call OnError + listener.OnError(fmt.Errorf("max attempts reached: %w", err)) + return + } + + // Exponential backoff + if delay < maxDelay { + delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier) + } + if delay > maxDelay { + delay = maxDelay + } + + select { + case <-closed: + return + case <-time.After(delay): + // continue to next attempt + } + } else { + // not retriable + listener.OnError(err) + return + } + } + } + } + }(listener, e.closedChan) + + return e.Stop, nil +} + +// Stop closes the token manager and releases any resources. +func (e *entraidTokenManager) Stop() error { + e.lock.Lock() + defer e.lock.Unlock() + + if e.closedChan == nil || e.listener == nil { + return ErrTokenManagerAlreadyStopped + } + + e.ctxCancel() + e.listener = nil + close(e.closedChan) + + return nil +} + +// durationToRenewal calculates the duration to the next token renewal. +// It returns the duration to the next token renewal based on the expiration refresh ratio and the lower bound duration. +// If the token is nil, it returns 0. +// If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now. +func (e *entraidTokenManager) durationToRenewal() time.Duration { + e.tokenRWLock.RLock() + if e.token == nil { + e.tokenRWLock.RUnlock() + return 0 + } + + timeTillExpiration := time.Until(e.token.ExpirationOn()) + e.tokenRWLock.RUnlock() + + // if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW + if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 { + return 0 + } + + // Calculate the time to renew the token based on the expiration refresh ratio + // Since timeTillExpiration is guarded by the lower bound, we can safely multiply it by the ratio + // and assume the duration is a positive number + duration := time.Duration(float64(timeTillExpiration) * e.expirationRefreshRatio) + + // if the duration will take us past the lower bound, return the duration to lower bound + if timeTillExpiration-e.lowerBoundDuration < duration { + return timeTillExpiration - e.lowerBoundDuration + } + + // return the calculated duration + return duration +} diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go new file mode 100644 index 0000000..2fb2e97 --- /dev/null +++ b/manager/token_manager_test.go @@ -0,0 +1,1677 @@ +package manager + +import ( + "context" + "fmt" + "log" + "math/rand" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var assertFuncNameMatches = func(t *testing.T, func1, func2 interface{}) { + funcName1 := runtime.FuncForPC(reflect.ValueOf(func1).Pointer()).Name() + funcName2 := runtime.FuncForPC(reflect.ValueOf(func2).Pointer()).Name() + assert.Equal(t, funcName1, funcName2) +} + +func TestTokenManager(t *testing.T) { + t.Parallel() + t.Run("Without IDP", func(t *testing.T) { + t.Parallel() + tokenManager, err := NewTokenManager(nil, + TokenManagerOptions{}, + ) + assert.Error(t, err) + assert.Nil(t, tokenManager) + }) + + t.Run("With IDP", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + }) +} + +func TestTokenManagerWithOptions(t *testing.T) { + t.Parallel() + t.Run("Bad Expiration Refresh Ration", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{ + ExpirationRefreshRatio: 5, + } + tokenManager, err := NewTokenManager(idp, options) + assert.Error(t, err) + assert.Nil(t, tokenManager) + }) + t.Run("With IDP and Options", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{ + ExpirationRefreshRatio: 0.5, + } + tokenManager, err := NewTokenManager(idp, options) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Equal(t, 0.5, tm.expirationRefreshRatio) + }) + t.Run("Default Options", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{} + tokenManager, err := NewTokenManager(idp, options) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Equal(t, DefaultExpirationRefreshRatio, tm.expirationRefreshRatio) + assert.NotNil(t, tm.retryOptions.IsRetryable) + assertFuncNameMatches(t, tm.retryOptions.IsRetryable, defaultIsRetryable) + assert.Equal(t, DefaultRetryOptionsMaxAttempts, tm.retryOptions.MaxAttempts) + assert.Equal(t, DefaultRetryOptionsInitialDelay, tm.retryOptions.InitialDelay) + assert.Equal(t, DefaultRetryOptionsMaxDelay, tm.retryOptions.MaxDelay) + assert.Equal(t, DefaultRetryOptionsBackoffMultiplier, tm.retryOptions.BackoffMultiplier) + }) +} + +func TestDefaultIdentityProviderResponseParserOr(t *testing.T) { + t.Parallel() + var f shared.IdentityProviderResponseParser = &mockIdentityProviderResponseParser{} + + result := defaultIdentityProviderResponseParserOr(f) + assert.NotNil(t, result) + assert.Equal(t, result, f) + + defaultParser := defaultIdentityProviderResponseParserOr(nil) + assert.NotNil(t, defaultParser) + assert.NotEqual(t, defaultParser, f) + assert.Equal(t, entraidIdentityProviderResponseParser, defaultParser) +} + +func TestDefaultIsRetryable(t *testing.T) { + t.Parallel() + // with network error timeout + t.Run("Non-Retryable Error", func(t *testing.T) { + t.Parallel() + err := &azcore.ResponseError{ + StatusCode: 500, + } + is := defaultIsRetryable(err) + assert.False(t, is) + }) + + t.Run("Nil Error", func(t *testing.T) { + t.Parallel() + var err error + is := defaultIsRetryable(err) + assert.True(t, is) + + is = defaultIsRetryable(nil) + assert.True(t, is) + }) + + t.Run("Retryable Error with Timeout", func(t *testing.T) { + t.Parallel() + err := newMockError(true) + result := defaultIsRetryable(err) + assert.True(t, result) + }) + t.Run("Retryable Error with Temporary", func(t *testing.T) { + t.Parallel() + err := newMockError(true) + result := defaultIsRetryable(err) + assert.True(t, result) + }) + + t.Run("Retryable Error with err parent of os.ErrDeadlineExceeded", func(t *testing.T) { + t.Parallel() + err := fmt.Errorf("timeout: %w", os.ErrDeadlineExceeded) + res := defaultIsRetryable(err) + assert.True(t, res) + }) +} + +func TestTokenManager_Close(t *testing.T) { + t.Parallel() + t.Run("Close", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + assert.NotPanics(t, func() { + err = tokenManager.Stop() + assert.Error(t, err) + }) + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + }) + assert.NotNil(t, tm.listener) + + err = tokenManager.Stop() + assert.Nil(t, tm.listener) + assert.NoError(t, err) + + assert.NotPanics(t, func() { + err = tokenManager.Stop() + assert.Error(t, err) + }) + }) + + t.Run("Close with Cancel", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + err = cancel() + assert.NoError(t, err) + assert.Nil(t, tm.listener) + err = cancel() + assert.Error(t, err) + assert.Nil(t, tm.listener) + }) + }) + t.Run("Close in multiple threads", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + var hasStopped int + var alreadyStopped int32 + wg := &sync.WaitGroup{} + + // Start 50000 goroutines to close the token manager + // and check if the listener is nil after each close. + numExecutions := 50000 + for i := 0; i < numExecutions; i++ { + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) + err := tokenManager.Stop() + if err == nil { + hasStopped += 1 + return + } else { + atomic.AddInt32(&alreadyStopped, 1) + } + assert.Nil(t, tm.listener) + assert.Error(t, err) + }() + } + wg.Wait() + assert.Nil(t, tm.listener) + assert.Equal(t, 1, hasStopped) + assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStopped)) + }) + }) +} + +func TestTokenManager_Start(t *testing.T) { + t.Parallel() + t.Run("Start in multiple threads", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + var hasStarted int + var alreadyStarted int32 + wg := &sync.WaitGroup{} + + numExecutions := 50000 + for i := 0; i < numExecutions; i++ { + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) + _, err := tokenManager.Start(listener) + if err == nil { + hasStarted += 1 + return + } else { + atomic.AddInt32(&alreadyStarted, 1) + } + assert.NotNil(t, tm.listener) + assert.Error(t, err) + }() + } + wg.Wait() + assert.NotNil(t, tm.listener) + assert.Equal(t, 1, hasStarted) + assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStarted)) + cancel, err := tokenManager.Start(listener) + assert.Nil(t, cancel) + assert.Error(t, err) + assert.NotNil(t, tm.listener) + }) + }) + + t.Run("concurrent stress token manager", func(t *testing.T) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + assert.NotPanics(t, func() { + last := &atomic.Int32{} + wg := &sync.WaitGroup{} + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + numExecutions := int32(50000) + for i := int32(0); i < numExecutions; i++ { + wg.Add(1) + go func(num int32) { + defer wg.Done() + var err error + time.Sleep(time.Duration(int64(rand.Intn(1000)+(300-int(num)/2)) * int64(time.Millisecond))) + last.Store(num) + if num%2 == 0 { + err = tokenManager.Stop() + } else { + l := &mockTokenListener{Id: num} + l.On("OnNext", testTokenValid).Return() + _, err = tokenManager.Start(l) + } + if err != nil { + if err != ErrTokenManagerAlreadyStopped && err != ErrTokenManagerAlreadyStarted { + // this is un unexpected error, fail the test + assert.Error(t, err) + } + } + }(i) + } + wg.Wait() + lastExecution := last.Load() + if lastExecution%2 == 0 { + if tm.listener != nil { + l := tm.listener.(*mockTokenListener) + log.Printf("FAILING WITH lastExecution [STARTED]:[LISTENER:%d]: %d", l.Id, lastExecution) + } + assert.Nil(t, tm.listener) + } else { + if tm.listener == nil { + log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution) + } + assert.NotNil(t, tm.listener) + cancel, err := tokenManager.Start(listener) + assert.Nil(t, cancel) + assert.Error(t, err) + // Close the token manager + err = tokenManager.Stop() + assert.Nil(t, err) + } + assert.Nil(t, tm.listener) + }) + }) +} + +func TestDefaultIdentityProviderResponseParser(t *testing.T) { + t.Parallel() + parser := &defaultIdentityProviderResponseParser{} + t.Run("Default IdentityProviderResponseParser with type AuthResult", func(t *testing.T) { + t.Parallel() + authResultVal := testAuthResult(time.Now().Add(time.Hour).UTC()) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + assert.Equal(t, authResultVal.ExpiresOn, token1.ExpirationOn()) + }) + t.Run("Default IdentityProviderResponseParser with type AccessToken", func(t *testing.T) { + t.Parallel() + accessToken := &azcore.AccessToken{ + Token: testJWTToken, + ExpiresOn: time.Now().Add(time.Hour).UTC(), + } + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAccessToken, + AccessTokenVal: accessToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + assert.Equal(t, accessToken.ExpiresOn, token1.ExpirationOn()) + assert.Equal(t, accessToken.Token, token1.RawCredentials()) + }) + t.Run("Default IdentityProviderResponseParser with type RawToken", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with expired JWT Token", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTExpiredToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with zero expiry JWT Token", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTWithZeroExpiryToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with type Unknown", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: "Unknown", + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + types := []string{ + shared.ResponseTypeAuthResult, + shared.ResponseTypeAccessToken, + shared.ResponseTypeRawToken, + } + for _, rt := range types { + t.Run(fmt.Sprintf("Default IdentityProviderResponseParser with response type %s and nil value", rt), func(t *testing.T) { + idpResponse := &authResult{ + ResultType: rt, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + } + + t.Run("Default IdentityProviderResponseParser with response nil", func(t *testing.T) { + t.Parallel() + token1, err := parser.ParseResponse(nil) + assert.Error(t, err) + assert.Nil(t, token1) + }) + t.Run("Default IdentityProviderResponseParser with expired token", func(t *testing.T) { + t.Parallel() + authResultVal := testAuthResult(time.Now().Add(-time.Hour)) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) +} + +func TestEntraidTokenManager_GetToken(t *testing.T) { + t.Parallel() + t.Run("GetToken", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + }) + + t.Run("GetToken with parse error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error")) + listener.On("OnError", mock.Anything).Return() + + cancel, err := tokenManager.Start(listener) + assert.Error(t, err) + assert.Nil(t, cancel) + assert.Nil(t, tm.listener) + }) + t.Run("GetToken with expired token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + + authResultVal := testAuthResult(time.Now().Add(-time.Hour)) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with nil token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with nil from parser", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) + mParser.On("ParseResponse", idpResponse).Return(nil, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with idp error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + idp.On("RequestToken", mock.Anything).Return(nil, fmt.Errorf("idp error")) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) +} + +func TestEntraidTokenManager_durationToRenewal(t *testing.T) { + t.Parallel() + t.Run("durationToRenewal", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ + LowerRefreshBound: time.Hour, + }) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + result := tm.durationToRenewal() + // returns 0 for nil token + assert.Equal(t, time.Duration(0), result) + + // get token that expires before the lower bound + assert.NotPanics(t, func() { + expiresSoon := testAuthResult(time.Now().Add(tm.lowerBoundDuration - time.Minute).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, + expiresSoon) + assert.NoError(t, err) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() + tm.token = nil + _, err = tm.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, tm.token) + + // return zero, should happen now since it expires before the lower bound + result = tm.durationToRenewal() + assert.Equal(t, time.Duration(0), result) + }) + + // get token that expires after the lower bound and expirationRefreshRatio to 1 + assert.NotPanics(t, func() { + tm.expirationRefreshRatio = 1 + expiresAfterlb := testAuthResult(time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, + expiresAfterlb) + assert.NoError(t, err) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() + tm.token = nil + _, err = tm.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, tm.token) + + // return time to lower bound, if the returned time will be after the lower bound + result = tm.durationToRenewal() + assert.InEpsilon(t, time.Until(tm.token.ExpirationOn().Add(-1*tm.lowerBoundDuration)), result, float64(time.Second)) + }) + + }) +} + +func TestEntraidTokenManager_Streaming(t *testing.T) { + t.Parallel() + t.Run("Start and Close", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + authResultVal := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() + token1 := token.New( + "test", + "test", + "test", + expiresOn, + time.Now(), + int64(time.Until(expiresOn)), + ) + + mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() + listener.On("OnNext", token1).Return().Once() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal / 10) + assert.NotNil(t, tm.listener) + assert.NoError(t, tokenManager.Stop()) + assert.Nil(t, tm.listener) + assert.Panics(t, func() { + close(tm.closedChan) + }) + + <-time.After(toRenewal) + assert.Error(t, tokenManager.Stop()) + mock.AssertExpectationsForObjects(t, idp, mParser, listener) + }) + + t.Run("Start and Listen with 0 renewal duration", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + LowerRefreshBound: time.Hour, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + done := make(chan struct{}) + var twice int32 + var start, stop time.Time + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + if atomic.LoadInt32(&twice) == 1 { + stop = time.Now() + close(done) + return + } else { + atomic.StoreInt32(&twice, 1) + start = time.Now() + } + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.Equal(t, time.Duration(0), toRenewal) + assert.True(t, expiresIn > toRenewal) + + // wait for request token to be called + <-done + // wait a bit for listener to be notified + <-time.After(10 * time.Millisecond) + assert.NoError(t, cancel()) + + assert.InDelta(t, stop.Sub(start), tm.retryOptions.InitialDelay, float64(200*time.Millisecond)) + + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnNext", 2) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with 0 renewal duration and closing the token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + LowerRefreshBound: time.Hour, + RetryOptions: RetryOptions{ + InitialDelay: 5 * time.Second, + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.Equal(t, time.Duration(0), toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(time.Duration(tm.retryOptions.InitialDelay / 2)) + assert.NoError(t, cancel()) + assert.Nil(t, tm.listener) + assert.Panics(t, func() { + close(tm.closedChan) + }) + + // called only once since the token manager was closed prior to initial delay passing + idp.AssertNumberOfCalls(t, "RequestToken", 1) + listener.AssertNumberOfCalls(t, "OnNext", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(toRenewal + time.Second) + + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with retriable error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0) + assert.NotNil(t, err) + }).Return().Maybe() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(true) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal + 100*time.Millisecond) + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnNext", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with NOT retriable error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + assert.NotNil(t, err) + }).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(false) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal + 100*time.Millisecond) + + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnNext", 1) + listener.AssertNumberOfCalls(t, "OnError", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with retriable error - max retries and max delay", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + maxAttempts := 3 + maxDelay := 500 * time.Millisecond + initialDelay := 100 * time.Millisecond + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + RetryOptions: RetryOptions{ + MaxAttempts: maxAttempts, + MaxDelay: maxDelay, + InitialDelay: initialDelay, + BackoffMultiplier: 10, + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + res.IDToken.Oid = "test" + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + res.IDToken.Oid = "test" + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + var start, end time.Time + var elapsed time.Duration + + _ = listener. + On("OnNext", mock.AnythingOfType("*token.Token")). + Run(func(_ mock.Arguments) { + start = time.Now() + }).Return() + maxAttemptsReached := make(chan struct{}) + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + end = time.Now() + elapsed = end.Sub(start) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "max attempts reached") + close(maxAttemptsReached) + }).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + noErrCall.Unset() + returnErr := newMockError(true) + + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) + + select { + case <-time.After(toRenewal + time.Duration(maxAttempts)*maxDelay): + assert.Fail(t, "Timeout - max retries not reached") + case <-maxAttemptsReached: + } + + // initialRenewal window, maxAttempts - 1 * max delay + the initial one which was lower than max delay + allDelaysShouldBe := toRenewal + allDelaysShouldBe += initialDelay + allDelaysShouldBe += time.Duration(maxAttempts-1) * maxDelay + + assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond)) + + idp.AssertNumberOfCalls(t, "RequestToken", tm.retryOptions.MaxAttempts+1) + listener.AssertNumberOfCalls(t, "OnNext", 1) + listener.AssertNumberOfCalls(t, "OnError", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + t.Run("Start and Listen and close during retries", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + RetryOptions: RetryOptions{ + MaxAttempts: 100, + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + maxAttemptsReached := make(chan struct{}) + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "max attempts reached") + close(maxAttemptsReached) + }).Return().Maybe() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(true) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(toRenewal + 500*time.Millisecond) + assert.Nil(t, cancel()) + + select { + case <-maxAttemptsReached: + assert.Fail(t, "Max retries reached, token manager not closed") + case <-tm.closedChan: + } + + <-time.After(50 * time.Millisecond) + + // maxAttempts + the initial one + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnError", 0) + mock.AssertExpectationsForObjects(t, idp, listener) + }) +} + +func testAuthResult(expiersOn time.Time) *public.AuthResult { + r := &public.AuthResult{ + ExpiresOn: expiersOn, + } + r.IDToken.Oid = "test" + return r +} + +func BenchmarkTokenManager_GetToken(b *testing.B) { + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = tokenManager.GetToken(false) + } +} + +func BenchmarkTokenManager_Start(b *testing.B) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = tokenManager.Start(listener) + } +} + +func BenchmarkTokenManager_Close(b *testing.B) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + _, err = tokenManager.Start(listener) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tokenManager.Stop() + } +} + +func BenchmarkTokenManager_durationToRenewal(b *testing.B) { + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ + LowerRefreshBound: time.Hour, + }) + if err != nil { + b.Fatal(err) + } + + tm, ok := tokenManager.(*entraidTokenManager) + if !ok { + b.Fatal("failed to cast to entraidTokenManager") + } + + expiresAfterlb := testAuthResult(time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, expiresAfterlb) + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) + _, err = tm.GetToken(false) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tm.durationToRenewal() + } +} + +// TestConcurrentTokenManagerOperations tests concurrent operations on the TokenManager +// to verify there are no deadlocks or race conditions in the implementation. +func TestConcurrentTokenManagerOperations(t *testing.T) { + t.Parallel() + + // Create a mock identity provider that returns predictable tokens + mockIdp := &concurrentMockIdentityProvider{ + tokenCounter: 0, + } + + // Create token manager with the mock provider + options := TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBound: 100 * time.Millisecond, + } + tm, err := NewTokenManager(mockIdp, options) + assert.NoError(t, err) + assert.NotNil(t, tm) + + // Number of concurrent operations to perform + const numConcurrentOps = 50 + const numGoroutines = 1000 + + // Channels to track received tokens and errors + tokenCh := make(chan *token.Token, numConcurrentOps*numGoroutines) + errorCh := make(chan error, numConcurrentOps*numGoroutines) + + // Channel to signal completion of all operations + doneCh := make(chan struct{}) + + // Track closers for cleanup + var closers sync.Map + + // Start multiple goroutines that will concurrently interact with the token manager + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(routineID int) { + defer wg.Done() + + for j := 0; j < numConcurrentOps; j++ { + // Create a listener for this operation + listener := &concurrentTestTokenListener{ + onNextFunc: func(t *token.Token) { + select { + case tokenCh <- t: + default: + // Channel full, ignore + } + }, + onErrorFunc: func(err error) { + select { + case errorCh <- err: + default: + // Channel full, ignore + } + }, + } + + // Choose operation based on a pattern + // Using modulo for a deterministic pattern that exercises all operations + opType := j % 3 + + // t.Logf("Goroutine %d, Operation %d: Performing operation type %d", routineID, j, opType) + + switch opType { + case 0: + // Start the token manager with a new listener + // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) + closeFunc, err := tm.Start(listener) + + if err != nil { + if err != ErrTokenManagerAlreadyStarted { + // t.Logf("Goroutine %d, Operation %d: Start failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to start token manager: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to start token manager: %v", routineID, j, err) + } + } + continue + } + + // t.Logf("Goroutine %d, Operation %d: Successfully started token manager", routineID, j) + // Store the closer for later cleanup + closerKey := fmt.Sprintf("closer-%d-%d", routineID, j) + closers.Store(closerKey, closeFunc) + + // Simulate some work + time.Sleep(time.Duration(500-rand.Intn(400)) * time.Millisecond) + + case 1: + // Get current token + //t.Logf("Goroutine %d, Operation %d: Getting token", routineID, j) + token, err := tm.GetToken(false) + if err != nil { + //t.Logf("Goroutine %d, Operation %d: GetToken failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to get token: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to get token: %v", routineID, j, err) + } + } else if token != nil { + //t.Logf("Goroutine %d, Operation %d: Successfully got token, expires: %v", routineID, j, token.ExpirationOn()) + select { + case tokenCh <- token: + default: + // Channel full, ignore + } + } + + case 2: + // Close a previously created token manager listener + // This simulates multiple subscriptions being created and destroyed + //t.Logf("Goroutine %d, Operation %d: Attempting to close a token manager", routineID, j) + closedAny := false + + closers.Range(func(key, value interface{}) bool { + if j%10 > 7 { // Only close some of the time based on a pattern + closedAny = true + //t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key) + + closeFunc := value.(StopFunc) + if err := closeFunc(); err != nil { + if err != ErrTokenManagerAlreadyStopped { + // t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to close token manager: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err) + } + } else { + //t.Logf("Goroutine %d, Operation %d: TokenManager was already stopped", routineID, j) + } + } else { + // t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j) + } + + closers.Delete(key) + return false // stop after finding one to close + } + return true + }) + + if !closedAny { + //t.Logf("Goroutine %d, Operation %d: No token manager to close or condition not met", routineID, j) + } + } + } + }(i) + } + + // Wait for all operations to complete or timeout + go func() { + wg.Wait() + close(doneCh) + }() + + // Use a timeout to detect deadlocks + select { + case <-doneCh: + // All operations completed successfully + t.Log("All concurrent operations completed successfully") + case <-time.After(30 * time.Second): + t.Fatal("test timed out, possible deadlock detected") + } + + // Count operations by type + var startCount, getTokenCount, closeCount int32 + + // Collect all ops from goroutines + for i := 0; i < numGoroutines; i++ { + for j := 0; j < numConcurrentOps; j++ { + opType := j % 3 + switch opType { + case 0: + atomic.AddInt32(&startCount, 1) + case 1: + atomic.AddInt32(&getTokenCount, 1) + case 2: + atomic.AddInt32(&closeCount, 1) + } + } + } + + // Clean up any remaining closers + closers.Range(func(key, value interface{}) bool { + closeFunc := value.(StopFunc) + _ = closeFunc() // Ignore errors during cleanup + return true + }) + + // Close channels to avoid goroutine leaks + close(tokenCh) + close(errorCh) + + // Count tokens and check their validity + var tokens []*token.Token + for t := range tokenCh { + tokens = append(tokens, t) + } + + // Collect and categorize errors + var startErrors, getTokenErrors, closeErrors, otherErrors []error + for err := range errorCh { + errStr := err.Error() + if strings.Contains(errStr, "failed to start token manager") { + startErrors = append(startErrors, err) + } else if strings.Contains(errStr, "failed to get token") { + getTokenErrors = append(getTokenErrors, err) + } else if strings.Contains(errStr, "failed to close token manager") { + closeErrors = append(closeErrors, err) + } else { + otherErrors = append(otherErrors, err) + t.Fatalf("Unexpected error during concurrent operations: %v", err) + } + } + + totalOps := startCount + getTokenCount + closeCount + expectedOps := int32(numGoroutines * numConcurrentOps) + + // Report operation counts + t.Logf("Concurrent test summary:") + t.Logf("- Total operations executed: %d (expected: %d)", totalOps, expectedOps) + t.Logf("- Start operations: %d (with %d errors)", startCount, len(startErrors)) + t.Logf("- GetToken operations: %d (with %d errors, %d successful)", + getTokenCount, len(getTokenErrors), len(tokens)) + t.Logf("- Close operations: %d (with %d errors)", closeCount, len(closeErrors)) + + // Some errors are expected due to concurrent operations + // but we should have received tokens successfully + assert.Equal(t, expectedOps, totalOps, "All operations should be accounted for") + assert.True(t, len(tokens) > 0, "Should have received tokens") + + // Verify the token manager still works after all the concurrent operations + finalListener := &concurrentTestTokenListener{ + onNextFunc: func(t *token.Token) { + // Just verify we get a token - don't use assert within this callback + if t == nil { + panic("Final token should not be nil") + } + }, + onErrorFunc: func(err error) { + t.Errorf("Unexpected error in final listener: %v", err) + }, + } + + closeFunc, err := tm.Start(finalListener) + if err != nil && err != ErrTokenManagerAlreadyStarted { + t.Fatalf("Failed to start token manager after concurrent operations: %v", err) + } + if closeFunc != nil { + defer closeFunc() + } + + // Get token one more time to verify everything still works + finalToken, err := tm.GetToken(true) + assert.NoError(t, err, "Should be able to get token after concurrent operations") + assert.NotNil(t, finalToken, "Final token should not be nil") +} + +// concurrentTestTokenListener is a test implementation of TokenListener for concurrent tests +type concurrentTestTokenListener struct { + onNextFunc func(*token.Token) + onErrorFunc func(error) +} + +func (l *concurrentTestTokenListener) OnNext(t *token.Token) { + if l.onNextFunc != nil { + l.onNextFunc(t) + } +} + +func (l *concurrentTestTokenListener) OnError(err error) { + if l.onErrorFunc != nil { + l.onErrorFunc(err) + } +} + +// concurrentMockIdentityProvider is a mock implementation of shared.IdentityProvider for concurrent tests +type concurrentMockIdentityProvider struct { + tokenCounter int + mutex sync.Mutex +} + +func (m *concurrentMockIdentityProvider) RequestToken(_ context.Context) (shared.IdentityProviderResponse, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.tokenCounter++ + + // Use the existing test JWT token which is already properly formatted + resp, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, testJWTToken) + if err != nil { + return nil, err + } + return resp, nil +} diff --git a/providers.go b/providers.go new file mode 100644 index 0000000..2d053d7 --- /dev/null +++ b/providers.go @@ -0,0 +1,155 @@ +package entraid + +import ( + "fmt" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis/go-redis/v9/auth" +) + +// CredentialsProviderOptions is a struct that holds the options for the credentials provider. +// It is used to configure the streaming credentials provider when requesting a token with a token manager. +type CredentialsProviderOptions struct { + // ClientID is the client ID of the identity. + // This is used to identify the identity when requesting a token. + ClientID string + + // TokenManagerOptions is the options for the token manager. + // This is used to configure the token manager when requesting a token. + TokenManagerOptions manager.TokenManagerOptions + + // tokenManagerFactory is a private field that can be injected from within the package. + // It is used to create a token manager for the credentials provider. + tokenManagerFactory func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) +} + +// defaultTokenManagerFactory is the default implementation of the token manager factory. +// It creates a new token manager using the provided identity provider and options. +func defaultTokenManagerFactory(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + return manager.NewTokenManager(provider, options) +} + +// getTokenManagerFactory returns the token manager factory to use. +// If no factory is provided, it returns the default implementation. +func (o *CredentialsProviderOptions) getTokenManagerFactory() func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + if o.tokenManagerFactory == nil { + return defaultTokenManagerFactory + } + return o.tokenManagerFactory +} + +// Managed identity type + +// ManagedIdentityCredentialsProviderOptions is a struct that holds the options for the managed identity credentials provider. +type ManagedIdentityCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + // It is used to specify the client ID, tenant ID, and scopes for the identity. + CredentialsProviderOptions + + // ManagedIdentityProviderOptions is the options for the managed identity provider. + // This is used to configure the managed identity provider when requesting a token. + identity.ManagedIdentityProviderOptions +} + +// NewManagedIdentityCredentialsProvider creates a new streaming credentials provider for managed identity. +// It uses the provided options to configure the provider. +// Use this when you want either a system assigned identity or a user assigned identity. +// The system assigned identity is automatically managed by Azure and does not require any additional configuration. +// The user assigned identity is a separate resource that can be managed independently. +func NewManagedIdentityCredentialsProvider(options ManagedIdentityCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the managed identity type. + idp, err := identity.NewManagedIdentityProvider(options.ManagedIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create managed identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + + return credentialsProvider, nil +} + +// ConfidentialCredentialsProviderOptions is a struct that holds the options for the confidential credentials provider. +// It is used to configure the credentials provider when requesting a token. +type ConfidentialCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + CredentialsProviderOptions + + // ConfidentialIdentityProviderOptions is the options for the confidential identity provider. + // This is used to configure the identity provider when requesting a token. + identity.ConfidentialIdentityProviderOptions +} + +// NewConfidentialCredentialsProvider creates a new confidential credentials provider. +// It uses client id and client credentials to authenticate with the identity provider. +// The client credentials can be either a client secret or a client certificate. +func NewConfidentialCredentialsProvider(options ConfidentialCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the client ID and client credentials. + idp, err := identity.NewConfidentialIdentityProvider(options.ConfidentialIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create confidential identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + return credentialsProvider, nil +} + +// DefaultAzureCredentialsProviderOptions is a struct that holds the options for the default azure credentials provider. +// It is used to configure the credentials provider when requesting a token. +type DefaultAzureCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + // It includes the clientId and TokenManagerOptions. + CredentialsProviderOptions + // DefaultAzureIdentityProviderOptions is the options for the default azure identity provider. + // This is used to configure the identity provider when requesting a token. + // It is used to specify the client ID, tenant ID, and scopes for the identity. + identity.DefaultAzureIdentityProviderOptions +} + +// NewDefaultAzureCredentialsProvider creates a new default azure credentials provider. +// It uses the default azure identity provider to authenticate with the identity provider. +// The default azure identity provider is a special type of identity provider that uses the default azure identity to authenticate. +// It is used to authenticate with the identity provider when requesting a token. +func NewDefaultAzureCredentialsProvider(options DefaultAzureCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the default azure identity type. + idp, err := identity.NewDefaultAzureIdentityProvider(options.DefaultAzureIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create default azure identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + return credentialsProvider, nil +} diff --git a/providers_test.go b/providers_test.go new file mode 100644 index 0000000..88d88b3 --- /dev/null +++ b/providers_test.go @@ -0,0 +1,588 @@ +package entraid + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/assert" +) + +func TestNewManagedIdentityCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options ManagedIdentityCredentialsProviderOptions + expectedError error + }{ + { + name: "valid managed identity options", + options: ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "system assigned identity", + options: ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: identity.SystemAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "invalid managed identity type", + options: ManagedIdentityCredentialsProviderOptions{ + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", + }, + }, + expectedError: errors.New("invalid managed identity type"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + provider, err := NewManagedIdentityCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + tk, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, tk.RawCredentials()) + assert.NoError(t, err) + } + }) + } +} + +func TestNewConfidentialCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options ConfidentialCredentialsProviderOptions + expectedError error + }{ + { + name: "valid confidential options with client secret", + options: ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + }, + expectedError: nil, + }, + { + name: "missing required fields", + options: ConfidentialCredentialsProviderOptions{ + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + CredentialsType: identity.ClientSecretCredentialType, + }, + }, + expectedError: errors.New("client ID is required"), + }, + { + name: "invalid credentials type", + options: ConfidentialCredentialsProviderOptions{ + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: "invalid-type", + }, + }, + expectedError: errors.New("invalid credentials type"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + credentials, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, credentials.RawCredentials()) + assert.NoError(t, err) + } + }) + } +} + +func TestNewDefaultAzureCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options DefaultAzureCredentialsProviderOptions + expectedError error + }{ + { + name: "valid default azure options", + options: DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "empty options", + options: DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + }, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + provider, err := NewDefaultAzureCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + tk, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, tk.RawCredentials()) + assert.NoError(t, err) + } + }) + } +} + +func TestCredentialsProviderInterface(t *testing.T) { + // Test that all providers implement the StreamingCredentialsProvider interface + tests := []struct { + name string + provider auth.StreamingCredentialsProvider + }{ + { + name: "managed identity provider", + provider: func() auth.StreamingCredentialsProvider { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + p, _ := NewManagedIdentityCredentialsProvider(options) + return p + }(), + }, + { + name: "confidential provider", + provider: func() auth.StreamingCredentialsProvider { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + p, _ := NewConfidentialCredentialsProvider(options) + return p + }(), + }, + { + name: "default azure provider", + provider: func() auth.StreamingCredentialsProvider { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + p, _ := NewDefaultAzureCredentialsProvider(options) + return p + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that the provider implements the interface by calling its methods + // Note: These are simplified tests as actual authentication would require Azure credentials + listener := &mockCredentialsListener{} + credentials, cancel, err := tt.provider.Subscribe(listener) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + assert.NoError(t, err) + }) + } +} + +func TestNewManagedIdentityCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewConfidentialCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewDefaultAzureCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewManagedIdentityCredentialsProvider_TokenManagerStartError(t *testing.T) { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } + + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) +} + +func TestNewConfidentialCredentialsProvider_TokenManagerStartError(t *testing.T) { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } + + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) +} + +func TestNewDefaultAzureCredentialsProvider_TokenManagerStartError(t *testing.T) { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } + + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) +} diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go new file mode 100644 index 0000000..f4bb7a3 --- /dev/null +++ b/shared/identity_provider_response.go @@ -0,0 +1,69 @@ +package shared + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/internal" + "github.com/redis-developer/go-redis-entraid/token" +) + +const ( + // ResponseTypeAuthResult is the type of the auth result. + ResponseTypeAuthResult = "AuthResult" + // ResponseTypeAccessToken is the type of the access token. + ResponseTypeAccessToken = "AccessToken" + // ResponseTypeRawToken is the type of the response when you have a raw string. + ResponseTypeRawToken = "RawToken" +) + +// IdentityProviderResponseParser is an interface that defines the methods for parsing the identity provider response. +// It is used to parse the response from the identity provider and extract the token. +// If not provided, the default implementation will be used. +type IdentityProviderResponseParser interface { + ParseResponse(response IdentityProviderResponse) (*token.Token, error) +} + +// IdentityProviderResponse is an interface that defines the +// type method for the identity provider response. It is used to +// identify the type of response returned by the identity provider. +// The type can be either AuthResult, AccessToken, or RawToken. You can +// use this interface to check the type of the response and handle it accordingly. +type IdentityProviderResponse interface { + // Type returns the type of identity provider response + Type() string +} + +// AuthResultIDPResponse is an interface that defines the method for getting the auth result. +type AuthResultIDPResponse interface { + AuthResult() public.AuthResult +} + +// AccessTokenIDPResponse is an interface that defines the method for getting the access token. +type AccessTokenIDPResponse interface { + AccessToken() azcore.AccessToken +} + +// RawTokenIDPResponse is an interface that defines the method for getting the raw token. +type RawTokenIDPResponse interface { + RawToken() string +} + +// IdentityProvider is an interface that defines the methods for an identity provider. +// It is used to request a token for authentication. +// The identity provider is responsible for providing the raw authentication token. +type IdentityProvider interface { + // RequestToken requests a token from the identity provider. + // The context is passed to the request to allow for cancellation and timeouts. + // It returns the token, the expiration time, and an error if any. + RequestToken(ctx context.Context) (IdentityProviderResponse, error) +} + +// NewIDPResponse creates a new auth result based on the type provided. +// It returns an IdentityProviderResponse interface. +// Type can be either AuthResult, AccessToken, or RawToken. +// Second argument is the result of the type provided in the first argument. +func NewIDPResponse(responseType string, result interface{}) (IdentityProviderResponse, error) { + return internal.NewIDPResp(responseType, result) +} diff --git a/shared/identity_provider_response_test.go b/shared/identity_provider_response_test.go new file mode 100644 index 0000000..715dce2 --- /dev/null +++ b/shared/identity_provider_response_test.go @@ -0,0 +1,330 @@ +package shared + +import ( + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" +) + +// Mock implementations for testing +type mockIDPResponse struct { + responseType string + authResult *public.AuthResult + accessToken *azcore.AccessToken + rawToken string +} + +func (m *mockIDPResponse) Type() string { + return m.responseType +} + +func (m *mockIDPResponse) AuthResult() public.AuthResult { + if m.authResult == nil { + return public.AuthResult{} + } + return *m.authResult +} + +func (m *mockIDPResponse) AccessToken() azcore.AccessToken { + if m.accessToken == nil { + return azcore.AccessToken{} + } + return *m.accessToken +} + +func (m *mockIDPResponse) RawToken() string { + return m.rawToken +} + +type mockIDPParser struct { + parseError error + token *token.Token +} + +func (m *mockIDPParser) ParseResponse(response IdentityProviderResponse) (*token.Token, error) { + if m.parseError != nil { + return nil, m.parseError + } + return m.token, nil +} + +type mockIDP struct { + response IdentityProviderResponse + err error +} + +func (m *mockIDP) RequestToken() (IdentityProviderResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestNewIDPResponse(t *testing.T) { + tests := []struct { + name string + responseType string + result interface{} + expectedError string + }{ + { + name: "Valid AuthResult pointer", + responseType: ResponseTypeAuthResult, + result: &public.AuthResult{}, + }, + { + name: "Valid AuthResult value", + responseType: ResponseTypeAuthResult, + result: public.AuthResult{}, + }, + { + name: "Valid AccessToken pointer", + responseType: ResponseTypeAccessToken, + result: &azcore.AccessToken{Token: "test-token"}, + }, + { + name: "Valid AccessToken value", + responseType: ResponseTypeAccessToken, + result: azcore.AccessToken{Token: "test-token"}, + }, + { + name: "Valid RawToken string", + responseType: ResponseTypeRawToken, + result: "test-token", + }, + { + name: "Valid RawToken string pointer", + responseType: ResponseTypeRawToken, + result: stringPtr("test-token"), + }, + { + name: "Nil result", + responseType: ResponseTypeAuthResult, + result: nil, + expectedError: "result cannot be nil", + }, + { + name: "Nil string pointer", + responseType: ResponseTypeRawToken, + result: (*string)(nil), + expectedError: "raw token cannot be nil", + }, + { + name: "Invalid AuthResult type", + responseType: ResponseTypeAuthResult, + result: "not-an-auth-result", + expectedError: "invalid auth result type: expected public.AuthResult or *public.AuthResult", + }, + { + name: "Invalid AccessToken type", + responseType: ResponseTypeAccessToken, + result: "not-an-access-token", + expectedError: "invalid access token type: expected azcore.AccessToken or *azcore.AccessToken", + }, + { + name: "Invalid RawToken type", + responseType: ResponseTypeRawToken, + result: 123, + expectedError: "invalid raw token type: expected string or *string", + }, + { + name: "Invalid response type", + responseType: "InvalidType", + result: "test", + expectedError: "unsupported identity provider response type: InvalidType", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := NewIDPResponse(tt.responseType, tt.result) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, resp) + return + } + + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.responseType, resp.Type()) + + switch tt.responseType { + case ResponseTypeAuthResult: + assert.NotNil(t, resp.(AuthResultIDPResponse).AuthResult()) + case ResponseTypeAccessToken: + assert.NotNil(t, resp.(AccessTokenIDPResponse).AccessToken()) + assert.NotEmpty(t, resp.(AccessTokenIDPResponse).AccessToken().Token) + case ResponseTypeRawToken: + assert.NotEmpty(t, resp.(RawTokenIDPResponse).RawToken()) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} + +func TestIdentityProviderResponse(t *testing.T) { + now := time.Now() + expires := now.Add(time.Hour) + + authResult := &public.AuthResult{ + AccessToken: "test-access-token", + ExpiresOn: expires, + } + + accessToken := &azcore.AccessToken{ + Token: "test-access-token", + ExpiresOn: expires, + } + + tests := []struct { + name string + response *mockIDPResponse + expectedType string + }{ + { + name: "AuthResult response", + response: &mockIDPResponse{ + responseType: ResponseTypeAuthResult, + authResult: authResult, + }, + expectedType: ResponseTypeAuthResult, + }, + { + name: "AccessToken response", + response: &mockIDPResponse{ + responseType: ResponseTypeAccessToken, + accessToken: accessToken, + }, + expectedType: ResponseTypeAccessToken, + }, + { + name: "RawToken response", + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-raw-token", + }, + expectedType: ResponseTypeRawToken, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expectedType, tt.response.Type()) + + switch tt.expectedType { + case ResponseTypeAuthResult: + result := tt.response.AuthResult() + assert.Equal(t, authResult.AccessToken, result.AccessToken) + assert.Equal(t, authResult.ExpiresOn, result.ExpiresOn) + case ResponseTypeAccessToken: + token := tt.response.AccessToken() + assert.Equal(t, accessToken.Token, token.Token) + assert.Equal(t, accessToken.ExpiresOn, token.ExpiresOn) + case ResponseTypeRawToken: + assert.Equal(t, "test-raw-token", tt.response.RawToken()) + } + }) + } +} + +func TestIdentityProvider(t *testing.T) { + tests := []struct { + name string + provider *mockIDP + wantErr bool + }{ + { + name: "Successful token request", + provider: &mockIDP{ + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + }, + wantErr: false, + }, + { + name: "Failed token request", + provider: &mockIDP{ + err: assert.AnError, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response, err := tt.provider.RequestToken() + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, ResponseTypeRawToken, response.Type()) + assert.Equal(t, "test-token", response.(RawTokenIDPResponse).RawToken()) + } + }) + } +} + +func TestIdentityProviderResponseParser(t *testing.T) { + now := time.Now() + expires := now.Add(time.Hour) + testToken := token.New("test-user", "test-password", "test-token", expires, now, int64(time.Hour.Seconds())) + + tests := []struct { + name string + parser *mockIDPParser + response IdentityProviderResponse + wantErr bool + wantToken *token.Token + }{ + { + name: "Successful parse", + parser: &mockIDPParser{ + token: testToken, + }, + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + wantErr: false, + wantToken: testToken, + }, + { + name: "Failed parse", + parser: &mockIDPParser{ + parseError: assert.AnError, + }, + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + wantErr: true, + wantToken: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := tt.parser.ParseResponse(tt.response) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, token) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantToken, token) + } + }) + } +} diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000..fafce60 --- /dev/null +++ b/token/token.go @@ -0,0 +1,83 @@ +package token + +import ( + "time" + + "github.com/redis/go-redis/v9/auth" +) + +// Ensure Token implements the auth.Credentials interface. +var _ auth.Credentials = (*Token)(nil) + +// New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live. +// NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance. +func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token { + return &Token{ + username: username, + password: password, + expiresOn: expiresOn, + receivedAt: receivedAt, + ttl: ttl, + rawToken: rawToken, + } +} + +// Token represents parsed authentication token used to access the Redis server. +// It implements the auth.Credentials interface. +type Token struct { + // username is the username of the user. + username string + // password is the password of the user. + password string + // expiresOn is the expiration time of the token. + expiresOn time.Time + // ttl is the time to live of the token. + ttl int64 + // rawToken is the authentication token. + rawToken string + // receivedAt is the time when the token was received. + receivedAt time.Time +} + +// BasicAuth returns the username and password for basic authentication. +func (t *Token) BasicAuth() (string, string) { + return t.username, t.password +} + +// RawCredentials returns the raw credentials for authentication. +func (t *Token) RawCredentials() string { + return t.RawToken() +} + +// RawToken returns the raw token. +func (t *Token) RawToken() string { + return t.rawToken +} + +// ReceivedAt returns the time when the token was received. +func (t *Token) ReceivedAt() time.Time { + return t.receivedAt +} + +// ExpirationOn returns the expiration time of the token. +func (t *Token) ExpirationOn() time.Time { + return t.expiresOn +} + +// TTL returns the time to live of the token. +func (t *Token) TTL() int64 { + return t.ttl +} + +// Copy creates a copy of the token. +func (t *Token) Copy() *Token { + return copyToken(t) +} + +// copyToken creates a copy of the token. +func copyToken(token *Token) *Token { + if token == nil { + return nil + } + return New(token.username, token.password, token.rawToken, token.expiresOn, token.receivedAt, token.ttl) +} diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..72a54ad --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,166 @@ +package token + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + t.Parallel() + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New("username", "password", "rawToken", expiration, receivedAt, ttl) + assert.Equal(t, "username", token.username) + assert.Equal(t, "password", token.password) + assert.Equal(t, "rawToken", token.rawToken) + assert.Equal(t, expiration, token.expiresOn) + assert.Equal(t, receivedAt, token.receivedAt) + assert.Equal(t, ttl, token.ttl) +} + +func TestBasicAuth(t *testing.T) { + t.Parallel() + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + baUsername, baPassword := token.BasicAuth() + assert.Equal(t, username, baUsername) + assert.Equal(t, password, baPassword) +} + +func TestRawCredentials(t *testing.T) { + t.Parallel() + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + rawCredentials := token.RawCredentials() + assert.Equal(t, rawToken, rawCredentials) + assert.Contains(t, rawCredentials, username) + assert.Contains(t, rawCredentials, password) +} + +func TestExpirationOn(t *testing.T) { + t.Parallel() + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + expirationOn := token.ExpirationOn() + assert.True(t, expirationOn.After(time.Now())) + assert.Equal(t, expiration, expirationOn) +} + +func TestTokenTTL(t *testing.T) { + t.Parallel() + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + assert.Equal(t, ttl, token.TTL()) +} + +func TestCopyToken(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + copiedToken := copyToken(token) + + assert.Equal(t, token.username, copiedToken.username) + assert.Equal(t, token.password, copiedToken.password) + assert.Equal(t, token.rawToken, copiedToken.rawToken) + assert.Equal(t, token.ttl, copiedToken.ttl) + assert.Equal(t, token.expiresOn, copiedToken.expiresOn) + assert.Equal(t, token.receivedAt, copiedToken.receivedAt) + + // change the copied token + copiedToken.expiresOn = time.Now().Add(-1 * time.Hour) + assert.NotEqual(t, token.expiresOn, copiedToken.expiresOn) + + // copy nil + copiedToken = copyToken(nil) + assert.Nil(t, copiedToken) + // copy empty token + copiedToken = copyToken(&Token{}) + assert.NotNil(t, copiedToken) + anotherCopy := copiedToken.Copy() + anotherCopy.rawToken = "changed" + assert.NotEqual(t, copiedToken, anotherCopy) +} + +func TestTokenReceivedAt(t *testing.T) { + t.Parallel() + // Create a token with a specific receivedAt time + receivedAt := time.Now() + token := New("username", "password", "rawToken", time.Now(), receivedAt, 3600) + + assert.True(t, token.receivedAt.After(time.Now().Add(-1*time.Hour))) + assert.True(t, token.receivedAt.Before(time.Now().Add(1*time.Hour))) + + // Check if the receivedAt time is set correctly + assert.Equal(t, receivedAt, token.ReceivedAt()) + + tcopiedToken := token.Copy() + // Check if the copied token has the same receivedAt time + assert.Equal(t, receivedAt, tcopiedToken.ReceivedAt()) + // Check if the copied token is not the same instance as the original token + assert.NotSame(t, token, tcopiedToken) + // Check if the copied token is a new instance + assert.NotNil(t, tcopiedToken) +} + +func BenchmarkNew(b *testing.B) { + now := time.Now() + b.ResetTimer() + for i := 0; i < b.N; i++ { + New("username", "password", "rawToken", now, now, 3600) + } +} + +func BenchmarkBasicAuth(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.BasicAuth() + } +} + +func BenchmarkRawCredentials(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.RawCredentials() + } +} + +func BenchmarkExpirationOn(b *testing.B) { + token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.ExpirationOn() + } +} + +func BenchmarkCopyToken(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.Copy() + } +} diff --git a/token_listener.go b/token_listener.go new file mode 100644 index 0000000..8515fda --- /dev/null +++ b/token_listener.go @@ -0,0 +1,36 @@ +package entraid + +import ( + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/token" +) + +// entraidTokenListener implements the TokenListener interface for the entraidCredentialsProvider. +// It listens for token updates and errors from the token manager and notifies the credentials provider. +type entraidTokenListener struct { + cp *entraidCredentialsProvider +} + +// tokenListenerFromCP creates a new entraidTokenListener from the given entraidCredentialsProvider. +// It is used to listen for token updates and errors from the token manager. +// This function is typically called when starting the token manager. +// It returns a pointer to the entraidTokenListener instance that is created from the credentials provider. +func tokenListenerFromCP(cp *entraidCredentialsProvider) manager.TokenListener { + return &entraidTokenListener{ + cp, + } +} + +// OnTokenNext is called when the token manager receives a new token. +// It notifies the credentials provider with the new token. +// This function is typically called when the token manager successfully retrieves a token. +func (l *entraidTokenListener) OnNext(t *token.Token) { + l.cp.onTokenNext(t) +} + +// OnTokenError is called when the token manager encounters an error. +// It notifies the credentials provider with the error. +// This function is typically called when the token manager fails to retrieve a token. +func (l *entraidTokenListener) OnError(err error) { + l.cp.onTokenError(err) +} diff --git a/token_listener_test.go b/token_listener_test.go new file mode 100644 index 0000000..d2d996c --- /dev/null +++ b/token_listener_test.go @@ -0,0 +1,46 @@ +package entraid + +import ( + "errors" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenListenerFromCP(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + require.NotNil(t, listener) + _, ok := listener.(*entraidTokenListener) + assert.True(t, ok, "listener should be of type entraidTokenListener") +} + +func TestOnTokenNext(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + now := time.Now() + testToken := token.New("test-user", "test-pass", "test-token", now.Add(time.Hour), now, 3600) + + listener.OnNext(testToken) + + // Since we can't directly access the internal state of entraidCredentialsProvider, + // we'll verify that the listener was created and the call didn't panic + assert.NotNil(t, listener) +} + +func TestOnTokenError(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + testError := errors.New("test error") + listener.OnError(testError) + + // Since we can't directly access the internal state of entraidCredentialsProvider, + // we'll verify that the listener was created and the call didn't panic + assert.NotNil(t, listener) +} diff --git a/version.go b/version.go index 4770c2f..fbcb668 100644 --- a/version.go +++ b/version.go @@ -1,4 +1,4 @@ -package redis +package entraid const version = "0.0.1" diff --git a/version_test.go b/version_test.go index e95de1c..8351bea 100644 --- a/version_test.go +++ b/version_test.go @@ -1,4 +1,4 @@ -package redis +package entraid import ( "testing"