From 9484ecbea3db2463a22e94d5b52a256ba276328d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 6 Jan 2025 18:49:22 +0000 Subject: [PATCH] Added support for expires_in Added support for expires_in different time formats for expire_in support for principle_id for app service in managed identity --- apps/confidential/confidential_test.go | 50 +++++----- apps/internal/base/base.go | 9 +- apps/internal/base/base_test.go | 48 ++++----- .../base/storage/partitioned_storage.go | 9 +- .../base/storage/partitioned_storage_test.go | 28 +++--- apps/internal/base/storage/storage.go | 2 +- apps/internal/base/storage/storage_test.go | 18 ++-- apps/internal/json/types/time/time.go | 49 ++++++++-- apps/internal/oauth/oauth.go | 2 +- .../ops/accesstokens/accesstokens_test.go | 98 ++++++++++++++++--- .../internal/oauth/ops/accesstokens/tokens.go | 61 ++---------- apps/managedidentity/managedidentity.go | 3 +- apps/managedidentity/managedidentity_test.go | 46 ++++----- apps/public/public_test.go | 12 +-- apps/tests/benchmarks/confidential.go | 12 +-- apps/tests/performance/performance_test.go | 10 +- 16 files changed, 265 insertions(+), 192 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 18ed2c24..1ec9aeba 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -342,11 +342,11 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } { t.Run("", func(t *testing.T) { tr := accesstokens.TokenResponse{ - AccessToken: token, - RefreshToken: refresh, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + AccessToken: token, + RefreshToken: refresh, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ PreferredUsername: params.preferredUsername, UPN: params.upn, @@ -461,12 +461,12 @@ func TestADFSTokenCaching(t *testing.T) { } fakeAT := fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ - AccessToken: "at1", - RefreshToken: "rt", - TokenType: "bearer", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + AccessToken: "at1", + RefreshToken: "rt", + TokenType: "bearer", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ ExpirationTime: time.Now().Add(time.Hour).Unix(), Name: "A", @@ -593,9 +593,9 @@ func TestNewCredFromCert(t *testing.T) { } t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) { client, err := fakeClient(accesstokens.TokenResponse{ - AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + AccessToken: token, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, cred, fakeAuthority, opts...) if err != nil { t.Fatal(err) @@ -743,7 +743,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { t.Fatal("token provider wasn't invoked") } if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn { - t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v) + t.Fatalf("expected ExpiresOnCalculated ~= %d seconds, got %d", expiresIn, v) } if ar.AccessToken != expectedToken { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) @@ -1383,11 +1383,11 @@ func TestWithAuthenticationScheme(t *testing.T) { t.Fatal(err) } client, err := fakeClient(accesstokens.TokenResponse{ - AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, - TokenType: "TokenType", + AccessToken: token, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "TokenType", }, cred, fakeAuthority) if err != nil { t.Fatal(err) @@ -1423,11 +1423,11 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) { t.Fatal(err) } client, err := fakeClient(accesstokens.TokenResponse{ - AccessToken: token, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, - TokenType: "Bearer", + AccessToken: token, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "Bearer", }, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant) if err != nil { t.Fatal(err) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 2e47c8f8..a441e09f 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -145,7 +145,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco Account: account, IDToken: tokenResponse.IDToken, AccessToken: tokenResponse.AccessToken, - ExpiresOn: tokenResponse.ExpiresOn.T, + ExpiresOn: getExpiryTime(tokenResponse), GrantedScopes: tokenResponse.GrantedScopes.Slice, Metadata: AuthResultMetadata{ TokenSource: IdentityProvider, @@ -153,6 +153,13 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco }, nil } +func getExpiryTime(tokenResponse accesstokens.TokenResponse) time.Time { + if tokenResponse.ExpiresOnCalculated.T.IsZero() || tokenResponse.ExpiresOnCalculated.T.Equal(time.Unix(0, 0)) { + return tokenResponse.ExpiresOn.T + } + return tokenResponse.ExpiresOnCalculated.T +} + // Client is a base client that provides access to common methods and primatives that // can be used by multiple clients. type Client struct { diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 7b7f52a7..0212a11e 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -49,12 +49,12 @@ func fakeClient(t *testing.T, opts ...Option) Client { } client.Token.AccessTokens = &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ - AccessToken: fakeAccessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - FamilyID: "family-id", - GrantedScopes: accesstokens.Scopes{Slice: testScopes}, - IDToken: fakeIDToken, - RefreshToken: fakeRefreshToken, + AccessToken: fakeAccessToken, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + FamilyID: "family-id", + GrantedScopes: accesstokens.Scopes{Slice: testScopes}, + IDToken: fakeIDToken, + RefreshToken: fakeRefreshToken, }, } client.Token.Authority = &fake.Authority{ @@ -133,12 +133,12 @@ func TestAcquireTokenSilentScopes(t *testing.T) { AuthnScheme: &authority.BearerAuthenticationScheme{}, }, accesstokens.TokenResponse{ - AccessToken: fakeAccessToken, - ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, - IDToken: fakeIDToken, - RefreshToken: fakeRefreshToken, + AccessToken: fakeAccessToken, + ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, + IDToken: fakeIDToken, + RefreshToken: fakeRefreshToken, }, ) storage.FakeValidate = nil @@ -177,10 +177,10 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) { AuthnScheme: &authority.BearerAuthenticationScheme{}, }, accesstokens.TokenResponse{ - AccessToken: expectedToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: grantedScopes}, - TokenType: "Bearer", + AccessToken: expectedToken, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: grantedScopes}, + TokenType: "Bearer", }, ) if err != nil { @@ -334,10 +334,10 @@ func TestCreateAuthenticationResult(t *testing.T) { { desc: "no declined scopes", input: accesstokens.TokenResponse{ - AccessToken: "accessToken", - ExpiresOn: internalTime.DurationTime{T: future}, - GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, - DeclinedScopes: nil, + AccessToken: "accessToken", + ExpiresOnCalculated: internalTime.DurationTime{T: future}, + GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, + DeclinedScopes: nil, }, want: AuthResult{ AccessToken: "accessToken", @@ -352,10 +352,10 @@ func TestCreateAuthenticationResult(t *testing.T) { { desc: "declined scopes", input: accesstokens.TokenResponse{ - AccessToken: "accessToken", - ExpiresOn: internalTime.DurationTime{T: future}, - GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, - DeclinedScopes: []string{"openid"}, + AccessToken: "accessToken", + ExpiresOnCalculated: internalTime.DurationTime{T: future}, + GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}}, + DeclinedScopes: []string{"openid"}, }, err: true, }, diff --git a/apps/internal/base/storage/partitioned_storage.go b/apps/internal/base/storage/partitioned_storage.go index c0931833..7f3cb7bd 100644 --- a/apps/internal/base/storage/partitioned_storage.go +++ b/apps/internal/base/storage/partitioned_storage.go @@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes realm, clientID, cachedAt, - tokenResponse.ExpiresOn.T, + getExpiryTime(tokenResponse), tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, @@ -177,6 +177,13 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes return account, nil } +func getExpiryTime(tokenResponse accesstokens.TokenResponse) time.Time { + if tokenResponse.ExpiresOnCalculated.T.IsZero() || tokenResponse.ExpiresOnCalculated.T.Equal(time.Unix(0, 0)) { + return tokenResponse.ExpiresOn.T + } + return tokenResponse.ExpiresOnCalculated.T +} + func (m *PartitionedManager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { md, err := m.aadMetadataFromCache(ctx, authorityInfo) if err != nil { diff --git a/apps/internal/base/storage/partitioned_storage_test.go b/apps/internal/base/storage/partitioned_storage_test.go index 86859cf2..42faafb7 100644 --- a/apps/internal/base/storage/partitioned_storage_test.go +++ b/apps/internal/base/storage/partitioned_storage_test.go @@ -57,13 +57,13 @@ func TestOBOAccessTokenScopes(t *testing.T) { _, err := mgr.Write( ap, accesstokens.TokenResponse{ - AccessToken: scope[0] + "-at", - ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: scope}, - IDToken: idt, - RefreshToken: upn + "-rt", - TokenType: "Bearer", + AccessToken: scope[0] + "-at", + ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: scope}, + IDToken: idt, + RefreshToken: upn + "-rt", + TokenType: "Bearer", }, ) if err != nil { @@ -119,13 +119,13 @@ func TestOBOPartitioning(t *testing.T) { account, err := mgr.Write( authParams[i], accesstokens.TokenResponse{ - AccessToken: upn + "-at", - ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: scopes}, - IDToken: idt, - RefreshToken: upn + "-rt", - TokenType: "Bearer", + AccessToken: upn + "-at", + ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID}, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: scopes}, + IDToken: idt, + RefreshToken: upn + "-rt", + TokenType: "Bearer", }, ) if err != nil { diff --git a/apps/internal/base/storage/storage.go b/apps/internal/base/storage/storage.go index 799ae87e..49553f3e 100644 --- a/apps/internal/base/storage/storage.go +++ b/apps/internal/base/storage/storage.go @@ -194,7 +194,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces realm, clientID, cachedAt, - tokenResponse.ExpiresOn.T, + getExpiryTime(tokenResponse), tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken, diff --git a/apps/internal/base/storage/storage_test.go b/apps/internal/base/storage/storage_test.go index 0570115c..66e0844e 100644 --- a/apps/internal/base/storage/storage_test.go +++ b/apps/internal/base/storage/storage_test.go @@ -1009,15 +1009,15 @@ func TestWrite(t *testing.T) { } expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)} tokenResponse := accesstokens.TokenResponse{ - AccessToken: "accessToken", - RefreshToken: "refreshToken", - IDToken: idToken, - FamilyID: "fid", - ClientInfo: clientInfo, - GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}}, - ExpiresOn: expiresOn, - ExtExpiresOn: internalTime.DurationTime{T: now}, - TokenType: "Bearer", + AccessToken: "accessToken", + RefreshToken: "refreshToken", + IDToken: idToken, + FamilyID: "fid", + ClientInfo: clientInfo, + GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}}, + ExpiresOnCalculated: expiresOn, + ExtExpiresOn: internalTime.DurationTime{T: now}, + TokenType: "Bearer", } authInfo := authority.Info{Host: "env", Tenant: "realm", AuthorityType: accAuth} authParams := authority.AuthParams{ diff --git a/apps/internal/json/types/time/time.go b/apps/internal/json/types/time/time.go index a1c99621..e06fc2a1 100644 --- a/apps/internal/json/types/time/time.go +++ b/apps/internal/json/types/time/time.go @@ -7,6 +7,7 @@ package time import ( "fmt" + "regexp" "strconv" "strings" "time" @@ -40,9 +41,9 @@ func (u *Unix) UnmarshalJSON(b []byte) error { // of a duration from now into a time.Time object. // Note: I'm not sure this is the best way to do this. What happens is we get a field // called "expires_in" that represents the seconds from now that this expires. We -// turn that into a time we call .ExpiresOn. But maybe we should be recording +// turn that into a time we call .ExpiresOnCalculated. But maybe we should be recording // when the token was received at .TokenRecieved and .ExpiresIn should remain as a duration. -// Then we could have a method called ExpiresOn(). Honestly, the whole thing is +// Then we could have a method called ExpiresOnCalculated(). Honestly, the whole thing is // bad because the server doesn't return a concrete time. I think this is // cleaner, but its not great either. type DurationTime struct { @@ -59,12 +60,48 @@ func (d DurationTime) MarshalJSON() ([]byte, error) { return []byte(fmt.Sprintf("%d", int64(dt*time.Second))), nil } -// UnmarshalJSON implements encoding/json.UnmarshalJSON(). -func (d *DurationTime) UnmarshalJSON(b []byte) error { - i, err := strconv.Atoi(strings.Trim(string(b), `"`)) +// UnmarshalJSON custom unmarshaler for DurationTime +func (t *DurationTime) UnmarshalJSON(b []byte) error { + // Remove the quotes around the JSON string + str := strings.Trim(string(b), `"`) + + // Try parsing as Unix timestamp (seconds since the Unix epoch) + if len(str) == 10 { + if unixTimestamp, err := strconv.ParseInt(str, 10, 64); err == nil { + t.T = time.Unix(unixTimestamp, 0) + return nil + } + } + + // Try parsing as ISO 8601 format (e.g., "2024-10-18T19:51:37.0000000+00:00") + iso8601Layout := "2006-01-02T15:04:05.9999999-07:00" + if parsedTime, err := time.Parse(iso8601Layout, str); err == nil { + t.T = parsedTime + return nil + } + + // Try parsing as MM/dd/yyyy HH:mm:ss format (e.g., "10/18/2024 19:51:37") + // Create regex pattern for MM/dd/yyyy HH:mm:ss + mmddyyyyLayout := `^(\d{2})/(\d{2})/(\d{4}) (\d{2}):(\d{2}):(\d{2})$` + if matched, _ := regexp.MatchString(mmddyyyyLayout, str); matched { + parsedTime, err := time.Parse("01/02/2006 15:04:05", str) + if err == nil { + t.T = parsedTime + return nil + } + } + + // Try parsing as yyyy-MM-dd HH:mm:ss format (e.g., "2024-10-18 19:51:37") + if parsedTime, err := time.Parse("2006-01-02 15:04:05", str); err == nil { + t.T = parsedTime + return nil + } + + i, err := strconv.Atoi(str) if err != nil { return fmt.Errorf("unix time(%s) could not be converted from string to int: %w", string(b), err) } - d.T = time.Now().Add(time.Duration(i) * time.Second) + t.T = time.Now().Add(time.Duration(i) * time.Second) return nil + } diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index e0653134..2863e374 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -122,7 +122,7 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams return accesstokens.TokenResponse{ TokenType: authParams.AuthnScheme.AccessTokenType(), AccessToken: tr.AccessToken, - ExpiresOn: internalTime.DurationTime{ + ExpiresOnCalculated: internalTime.DurationTime{ T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second), }, GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes}, diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index 7d8a87a9..95df4696 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -758,19 +758,21 @@ func TestTokenResponseUnmarshal(t *testing.T) { }, { desc: "Success", - payload: ` + payload: fmt.Sprintf(` { "access_token": "secret", "expires_in": 86399, + "expires_on": %d, "ext_expires_in": 86399, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" - }`, + }`, time.Now().Add(time.Hour*10).Unix()), want: TokenResponse{ - AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, - GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + AccessToken: "secret", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour * 10)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", UTID: "utid", @@ -784,15 +786,88 @@ func TestTokenResponseUnmarshal(t *testing.T) { { "access_token": "secret", "expires_on": %d, + "expires_in": 86399, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, time.Now().Add(time.Hour*10).Unix()), + want: TokenResponse{ + AccessToken: "secret", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour * 10)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, { + desc: "Success", + payload: fmt.Sprintf(` + { + "access_token": "secret", + "expires_on": "%s", + "expires_in": 86399, "ext_expires_in": 86399, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" - }`, time.Now().Add(time.Hour).Unix()), + }`, time.Now().Add(time.Hour*10).Format("2006-01-02T15:04:05.0000000-07:00")), want: TokenResponse{ - AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, - GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + AccessToken: "secret", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour * 10)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success", + payload: fmt.Sprintf(` + { + "access_token": "secret", + "expires_on": "%s", + "expires_in": 86399, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, time.Now().Add(time.Hour*10).Format("2006-01-02 15:04:05")), + want: TokenResponse{ + AccessToken: "secret", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour * 10)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success", + payload: fmt.Sprintf(` + { + "access_token": "secret", + "expires_on": "%s", + "expires_in": 86399, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, time.Now().Add(time.Hour*10).Format("01/02/2006 15:04:05")), + want: TokenResponse{ + AccessToken: "secret", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour * 10)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", UTID: "utid", @@ -804,7 +879,6 @@ func TestTokenResponseUnmarshal(t *testing.T) { for _, test := range tests { jwtDecoder = test.jwtDecoder - got := TokenResponse{} err := json.Unmarshal([]byte(test.payload), &got) switch { diff --git a/apps/internal/oauth/ops/accesstokens/tokens.go b/apps/internal/oauth/ops/accesstokens/tokens.go index 1c5085a2..a9071eb3 100644 --- a/apps/internal/oauth/ops/accesstokens/tokens.go +++ b/apps/internal/oauth/ops/accesstokens/tokens.go @@ -170,65 +170,20 @@ type TokenResponse struct { RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` - FamilyID string `json:"foci"` - IDToken IDToken `json:"id_token"` - ClientInfo ClientInfo `json:"client_info"` - ExpiresOn internalTime.DurationTime `json:"-"` - ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` - GrantedScopes Scopes `json:"scope"` - DeclinedScopes []string // This is derived + FamilyID string `json:"foci"` + IDToken IDToken `json:"id_token"` + ClientInfo ClientInfo `json:"client_info"` + ExpiresOnCalculated internalTime.DurationTime `json:"expires_in,omitempty"` + ExpiresOn internalTime.DurationTime `json:"expires_on,omitempty"` + ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` + GrantedScopes Scopes `json:"scope"` + DeclinedScopes []string // This is derived AdditionalFields map[string]interface{} scopesComputed bool } -func (tr *TokenResponse) UnmarshalJSON(data []byte) error { - type Alias TokenResponse - aux := &struct { - ExpiresIn json.Number `json:"expires_in"` - ExpiresOn json.Number `json:"expires_on"` - *Alias - }{ - Alias: (*Alias)(tr), - } - - // Unmarshal the JSON data into the aux struct - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - // Helper function to parse JSON number into int64 - parseDuration := func(num json.Number) (int64, error) { - if num == "" { - return 0, nil - } - return num.Int64() - } - - // Try to parse ExpiresIn first, then fallback to ExpiresOn - if duration, err := parseDuration(aux.ExpiresIn); err != nil { - return err - } else if duration > 0 { - tr.ExpiresOn = internalTime.DurationTime{T: time.Now().Add(time.Duration(duration) * time.Second)} - } else if duration == 0 || aux.ExpiresOn != "" { - // If ExpiresIn is zero, check ExpiresOn - if duration, err := parseDuration(aux.ExpiresOn); err != nil { - return err - } else if duration > 0 { - tr.ExpiresOn = internalTime.DurationTime{T: time.Unix(duration, 0)} - println(tr.ExpiresOn.T.String()) - - } else { - return errors.New("expires_in and expires_on are both missing or invalid") - } - } else { - return errors.New("expires_in or expires_on must be present in the response") - } - - return nil -} - // ComputeScope computes the final scopes based on what was granted by the server and // what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted // This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 83d8ead0..295b30bc 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -48,6 +48,7 @@ const ( // UAMI query parameter name miQueryParameterClientId = "client_id" miQueryParameterObjectId = "object_id" + miQueryParameterPrincipalId = "principal_id" miQueryParameterResourceIdIMDS = "msi_res_id" miQueryParameterResourceId = "mi_res_id" @@ -480,7 +481,7 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (* case UserAssignedResourceID: q.Set(miQueryParameterResourceId, string(t)) case UserAssignedObjectID: - q.Set(miQueryParameterObjectId, string(t)) + q.Set(miQueryParameterPrincipalId, string(t)) case systemAssignedValue: default: return nil, fmt.Errorf("unsupported type %T", id) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 205b2671..3be92f2a 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -37,8 +37,8 @@ const ( type SuccessfulResponse struct { AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in"` - ExpiresOn int64 `json:"expires_on"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresOn int64 `json:"expires_on,omitempty"` Resource string `json:"resource"` TokenType string `json:"token_type"` } @@ -48,27 +48,19 @@ type ErrorResponse struct { Desc string `json:"error_description"` } -func getSuccessfulResponse(resource string, doesHaveExpireIn bool) ([]byte, error) { +func getSuccessfulResponse(resource string) ([]byte, error) { var response SuccessfulResponse - if doesHaveExpireIn { - duration := 10 * time.Minute - expiresIn := duration.Seconds() - response = SuccessfulResponse{ - AccessToken: token, - ExpiresIn: int64(expiresIn), - Resource: resource, - TokenType: "Bearer", - } - } else { - println(time.Now().Add(time.Hour).Unix()) - response = SuccessfulResponse{ - AccessToken: token, - ExpiresOn: time.Now().Add(time.Hour).Unix(), - Resource: resource, - TokenType: "Bearer", - } - println(response.ExpiresOn) + + duration := 10 * time.Minute + expiresIn := duration.Seconds() + response = SuccessfulResponse{ + AccessToken: token, + ExpiresIn: int64(expiresIn), + ExpiresOn: time.Now().Add(time.Hour).Unix(), + Resource: resource, + TokenType: "Bearer", } + jsonResponse, err := json.Marshal(response) return jsonResponse, err } @@ -291,7 +283,7 @@ func Test_RetryPolicy_For_AcquireToken(t *testing.T) { })) } if !testCase.expectedFail { - successRespBody, err := getSuccessfulResponse(resource, true) + successRespBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } @@ -393,7 +385,7 @@ func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) { var localUrl *url.URL mockClient := mock.Client{} - responseBody, err := getSuccessfulResponse(resource, true) + responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf(errorFormingJsonResponse, err.Error()) } @@ -485,7 +477,7 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) { var localUrl *url.URL mockClient := mock.Client{} - responseBody, err := getSuccessfulResponse(resource, false) + responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf(errorFormingJsonResponse, err.Error()) } @@ -528,7 +520,7 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) { t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) } case UserAssignedObjectID: - if query.Get(miQueryParameterObjectId) != i.value() { + if query.Get(miQueryParameterPrincipalId) != i.value() { t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) } } @@ -583,7 +575,7 @@ func TestAzureArc(t *testing.T) { localUrl = r.URL })) - responseBody, err := getSuccessfulResponse(resource, true) + responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf(errorFormingJsonResponse, err.Error()) } @@ -754,7 +746,7 @@ func TestAzureArcErrors(t *testing.T) { mock.WithHTTPHeader(headers), ) - responseBody, err := getSuccessfulResponse(resource, true) + responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf(errorFormingJsonResponse, err.Error()) } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index f1ec159d..3bfdaaa3 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -306,12 +306,12 @@ func TestADFSTokenCaching(t *testing.T) { } fakeAT := fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ - AccessToken: "at1", - RefreshToken: "rt", - TokenType: "Bearer", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + AccessToken: "at1", + RefreshToken: "rt", + TokenType: "Bearer", + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, IDToken: accesstokens.IDToken{ ExpirationTime: time.Now().Add(time.Hour).Unix(), Name: "A", diff --git a/apps/tests/benchmarks/confidential.go b/apps/tests/benchmarks/confidential.go index 802bafbb..0cfdf990 100644 --- a/apps/tests/benchmarks/confidential.go +++ b/apps/tests/benchmarks/confidential.go @@ -39,9 +39,9 @@ func fakeClient() (base.Client, error) { return base.New("fake_client_id", "https://fake_authority/fake", &oauth.Client{ AccessTokens: &fake.AccessTokens{ AccessToken: accesstokens.TokenResponse{ - AccessToken: accessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + AccessToken: accessToken, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, }, }, Authority: &fake.Authority{ @@ -88,9 +88,9 @@ func populateTokenCache(client base.Client, params testParams) execTime { // we use this to add a fake token to the cache. // each token has a different scope which is what makes them unique _, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{ - AccessToken: accessToken, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.FormatInt(int64(i), 10)}}, + AccessToken: accessToken, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.FormatInt(int64(i), 10)}}, }, true) if err != nil { panic(err) diff --git a/apps/tests/performance/performance_test.go b/apps/tests/performance/performance_test.go index 66050b72..163c4e03 100644 --- a/apps/tests/performance/performance_test.go +++ b/apps/tests/performance/performance_test.go @@ -52,11 +52,11 @@ func populateCache(users int, tokens int, authParams authority.AuthParams, clien scope := fmt.Sprintf("scope%d", token) _, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{ - AccessToken: fmt.Sprintf("fake_access_token%d", user), - RefreshToken: "fake_refresh_token", - ClientInfo: accesstokens.ClientInfo{UID: "my_uid", UTID: fmt.Sprintf("%dmy_utid", user)}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: []string{scope}}, + AccessToken: fmt.Sprintf("fake_access_token%d", user), + RefreshToken: "fake_refresh_token", + ClientInfo: accesstokens.ClientInfo{UID: "my_uid", UTID: fmt.Sprintf("%dmy_utid", user)}, + ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: []string{scope}}, IDToken: accesstokens.IDToken{ RawToken: "x.e30", },