Skip to content

Commit

Permalink
Added support for expires_in
Browse files Browse the repository at this point in the history
Added support for expires_in
different time formats for expire_in
support for principle_id for app service in managed identity
  • Loading branch information
4gust committed Jan 6, 2025
1 parent 818cdc9 commit 9484ecb
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 192 deletions.
50 changes: 25 additions & 25 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,11 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
} {
t.Run("", func(t *testing.T) {
tr := accesstokens.TokenResponse{
AccessToken: token,
RefreshToken: refresh,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
AccessToken: token,
RefreshToken: refresh,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
PreferredUsername: params.preferredUsername,
UPN: params.upn,
Expand Down Expand Up @@ -461,12 +461,12 @@ func TestADFSTokenCaching(t *testing.T) {
}
fakeAT := fake.AccessTokens{
AccessToken: accesstokens.TokenResponse{
AccessToken: "at1",
RefreshToken: "rt",
TokenType: "bearer",
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
AccessToken: "at1",
RefreshToken: "rt",
TokenType: "bearer",
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
ExpirationTime: time.Now().Add(time.Hour).Unix(),
Name: "A",
Expand Down Expand Up @@ -593,9 +593,9 @@ func TestNewCredFromCert(t *testing.T) {
}
t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) {
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
AccessToken: token,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
}, cred, fakeAuthority, opts...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -743,7 +743,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
t.Fatal("token provider wasn't invoked")
}
if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn {
t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v)
t.Fatalf("expected ExpiresOnCalculated ~= %d seconds, got %d", expiresIn, v)
}
if ar.AccessToken != expectedToken {
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
Expand Down Expand Up @@ -1383,11 +1383,11 @@ func TestWithAuthenticationScheme(t *testing.T) {
t.Fatal(err)
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "TokenType",
AccessToken: token,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "TokenType",
}, cred, fakeAuthority)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1423,11 +1423,11 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
t.Fatal(err)
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
AccessToken: token,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant)
if err != nil {
t.Fatal(err)
Expand Down
9 changes: 8 additions & 1 deletion apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,21 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
Account: account,
IDToken: tokenResponse.IDToken,
AccessToken: tokenResponse.AccessToken,
ExpiresOn: tokenResponse.ExpiresOn.T,
ExpiresOn: getExpiryTime(tokenResponse),
GrantedScopes: tokenResponse.GrantedScopes.Slice,
Metadata: AuthResultMetadata{
TokenSource: IdentityProvider,
},
}, nil
}

func getExpiryTime(tokenResponse accesstokens.TokenResponse) time.Time {
if tokenResponse.ExpiresOnCalculated.T.IsZero() || tokenResponse.ExpiresOnCalculated.T.Equal(time.Unix(0, 0)) {
return tokenResponse.ExpiresOn.T
}
return tokenResponse.ExpiresOnCalculated.T
}

// Client is a base client that provides access to common methods and primatives that
// can be used by multiple clients.
type Client struct {
Expand Down
48 changes: 24 additions & 24 deletions apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func fakeClient(t *testing.T, opts ...Option) Client {
}
client.Token.AccessTokens = &fake.AccessTokens{
AccessToken: accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
FamilyID: "family-id",
GrantedScopes: accesstokens.Scopes{Slice: testScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
AccessToken: fakeAccessToken,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
FamilyID: "family-id",
GrantedScopes: accesstokens.Scopes{Slice: testScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
},
}
client.Token.Authority = &fake.Authority{
Expand Down Expand Up @@ -133,12 +133,12 @@ func TestAcquireTokenSilentScopes(t *testing.T) {
AuthnScheme: &authority.BearerAuthenticationScheme{},
},
accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
AccessToken: fakeAccessToken,
ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
},
)
storage.FakeValidate = nil
Expand Down Expand Up @@ -177,10 +177,10 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) {
AuthnScheme: &authority.BearerAuthenticationScheme{},
},
accesstokens.TokenResponse{
AccessToken: expectedToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: grantedScopes},
TokenType: "Bearer",
AccessToken: expectedToken,
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: grantedScopes},
TokenType: "Bearer",
},
)
if err != nil {
Expand Down Expand Up @@ -334,10 +334,10 @@ func TestCreateAuthenticationResult(t *testing.T) {
{
desc: "no declined scopes",
input: accesstokens.TokenResponse{
AccessToken: "accessToken",
ExpiresOn: internalTime.DurationTime{T: future},
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: nil,
AccessToken: "accessToken",
ExpiresOnCalculated: internalTime.DurationTime{T: future},
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: nil,
},
want: AuthResult{
AccessToken: "accessToken",
Expand All @@ -352,10 +352,10 @@ func TestCreateAuthenticationResult(t *testing.T) {
{
desc: "declined scopes",
input: accesstokens.TokenResponse{
AccessToken: "accessToken",
ExpiresOn: internalTime.DurationTime{T: future},
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: []string{"openid"},
AccessToken: "accessToken",
ExpiresOnCalculated: internalTime.DurationTime{T: future},
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: []string{"openid"},
},
err: true,
},
Expand Down
9 changes: 8 additions & 1 deletion apps/internal/base/storage/partitioned_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
realm,
clientID,
cachedAt,
tokenResponse.ExpiresOn.T,
getExpiryTime(tokenResponse),
tokenResponse.ExtExpiresOn.T,
target,
tokenResponse.AccessToken,
Expand Down Expand Up @@ -177,6 +177,13 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
return account, nil
}

func getExpiryTime(tokenResponse accesstokens.TokenResponse) time.Time {
if tokenResponse.ExpiresOnCalculated.T.IsZero() || tokenResponse.ExpiresOnCalculated.T.Equal(time.Unix(0, 0)) {
return tokenResponse.ExpiresOn.T
}
return tokenResponse.ExpiresOnCalculated.T
}

func (m *PartitionedManager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
md, err := m.aadMetadataFromCache(ctx, authorityInfo)
if err != nil {
Expand Down
28 changes: 14 additions & 14 deletions apps/internal/base/storage/partitioned_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ func TestOBOAccessTokenScopes(t *testing.T) {
_, err := mgr.Write(
ap,
accesstokens.TokenResponse{
AccessToken: scope[0] + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: scope},
IDToken: idt,
RefreshToken: upn + "-rt",
TokenType: "Bearer",
AccessToken: scope[0] + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: scope},
IDToken: idt,
RefreshToken: upn + "-rt",
TokenType: "Bearer",
},
)
if err != nil {
Expand Down Expand Up @@ -119,13 +119,13 @@ func TestOBOPartitioning(t *testing.T) {
account, err := mgr.Write(
authParams[i],
accesstokens.TokenResponse{
AccessToken: upn + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: scopes},
IDToken: idt,
RefreshToken: upn + "-rt",
TokenType: "Bearer",
AccessToken: upn + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOnCalculated: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: scopes},
IDToken: idt,
RefreshToken: upn + "-rt",
TokenType: "Bearer",
},
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion apps/internal/base/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
realm,
clientID,
cachedAt,
tokenResponse.ExpiresOn.T,
getExpiryTime(tokenResponse),
tokenResponse.ExtExpiresOn.T,
target,
tokenResponse.AccessToken,
Expand Down
18 changes: 9 additions & 9 deletions apps/internal/base/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1009,15 +1009,15 @@ func TestWrite(t *testing.T) {
}
expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)}
tokenResponse := accesstokens.TokenResponse{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
IDToken: idToken,
FamilyID: "fid",
ClientInfo: clientInfo,
GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}},
ExpiresOn: expiresOn,
ExtExpiresOn: internalTime.DurationTime{T: now},
TokenType: "Bearer",
AccessToken: "accessToken",
RefreshToken: "refreshToken",
IDToken: idToken,
FamilyID: "fid",
ClientInfo: clientInfo,
GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}},
ExpiresOnCalculated: expiresOn,
ExtExpiresOn: internalTime.DurationTime{T: now},
TokenType: "Bearer",
}
authInfo := authority.Info{Host: "env", Tenant: "realm", AuthorityType: accAuth}
authParams := authority.AuthParams{
Expand Down
49 changes: 43 additions & 6 deletions apps/internal/json/types/time/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package time

import (
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -40,9 +41,9 @@ func (u *Unix) UnmarshalJSON(b []byte) error {
// of a duration from now into a time.Time object.
// Note: I'm not sure this is the best way to do this. What happens is we get a field
// called "expires_in" that represents the seconds from now that this expires. We
// turn that into a time we call .ExpiresOn. But maybe we should be recording
// turn that into a time we call .ExpiresOnCalculated. But maybe we should be recording
// when the token was received at .TokenRecieved and .ExpiresIn should remain as a duration.
// Then we could have a method called ExpiresOn(). Honestly, the whole thing is
// Then we could have a method called ExpiresOnCalculated(). Honestly, the whole thing is
// bad because the server doesn't return a concrete time. I think this is
// cleaner, but its not great either.
type DurationTime struct {
Expand All @@ -59,12 +60,48 @@ func (d DurationTime) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%d", int64(dt*time.Second))), nil
}

// UnmarshalJSON implements encoding/json.UnmarshalJSON().
func (d *DurationTime) UnmarshalJSON(b []byte) error {
i, err := strconv.Atoi(strings.Trim(string(b), `"`))
// UnmarshalJSON custom unmarshaler for DurationTime
func (t *DurationTime) UnmarshalJSON(b []byte) error {
// Remove the quotes around the JSON string
str := strings.Trim(string(b), `"`)

// Try parsing as Unix timestamp (seconds since the Unix epoch)
if len(str) == 10 {
if unixTimestamp, err := strconv.ParseInt(str, 10, 64); err == nil {
t.T = time.Unix(unixTimestamp, 0)
return nil
}
}

// Try parsing as ISO 8601 format (e.g., "2024-10-18T19:51:37.0000000+00:00")
iso8601Layout := "2006-01-02T15:04:05.9999999-07:00"
if parsedTime, err := time.Parse(iso8601Layout, str); err == nil {
t.T = parsedTime
return nil
}

// Try parsing as MM/dd/yyyy HH:mm:ss format (e.g., "10/18/2024 19:51:37")
// Create regex pattern for MM/dd/yyyy HH:mm:ss
mmddyyyyLayout := `^(\d{2})/(\d{2})/(\d{4}) (\d{2}):(\d{2}):(\d{2})$`
if matched, _ := regexp.MatchString(mmddyyyyLayout, str); matched {
parsedTime, err := time.Parse("01/02/2006 15:04:05", str)
if err == nil {
t.T = parsedTime
return nil
}
}

// Try parsing as yyyy-MM-dd HH:mm:ss format (e.g., "2024-10-18 19:51:37")
if parsedTime, err := time.Parse("2006-01-02 15:04:05", str); err == nil {
t.T = parsedTime
return nil
}

i, err := strconv.Atoi(str)
if err != nil {
return fmt.Errorf("unix time(%s) could not be converted from string to int: %w", string(b), err)
}
d.T = time.Now().Add(time.Duration(i) * time.Second)
t.T = time.Now().Add(time.Duration(i) * time.Second)
return nil

}
2 changes: 1 addition & 1 deletion apps/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams
return accesstokens.TokenResponse{
TokenType: authParams.AuthnScheme.AccessTokenType(),
AccessToken: tr.AccessToken,
ExpiresOn: internalTime.DurationTime{
ExpiresOnCalculated: internalTime.DurationTime{
T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second),
},
GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
Expand Down
Loading

0 comments on commit 9484ecb

Please sign in to comment.