From cf36585e565112f9b72d2aa0220664563726c973 Mon Sep 17 00:00:00 2001 From: Michael Grosse Huelsewiesche Date: Wed, 13 Sep 2023 16:53:04 -0400 Subject: [PATCH] Refactor of OAuth support with working tests, still in progress --- packages/core/src/callback/index.ts | 15 +- packages/node/src/app/analytics-node.ts | 10 +- packages/node/src/app/settings.ts | 4 +- ...uth-util.test.ts => token-manager.test.ts} | 160 +++++---- packages/node/src/lib/oauth-util.ts | 323 +++++++++--------- packages/node/src/lib/token-manager.ts | 261 ++++++++++++++ .../node/src/plugins/segmentio/publisher.ts | 58 +--- 7 files changed, 562 insertions(+), 269 deletions(-) rename packages/node/src/lib/__tests__/{oauth-util.test.ts => token-manager.test.ts} (51%) create mode 100644 packages/node/src/lib/token-manager.ts diff --git a/packages/core/src/callback/index.ts b/packages/core/src/callback/index.ts index 8bc1156cd..16f58ae23 100644 --- a/packages/core/src/callback/index.ts +++ b/packages/core/src/callback/index.ts @@ -16,8 +16,19 @@ export function pTimeout(promise: Promise, timeout: number): Promise { }) } -export function sleep(timeoutInMs: number): Promise { - return new Promise((resolve) => setTimeout(resolve, timeoutInMs)) +export function sleep( + timeoutInMs: number, + signal?: AbortSignal +): Promise { + return new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, timeoutInMs) + if (signal) { + signal.addEventListener('abort', () => { + clearTimeout(timeout) + reject(new DOMException('Aborted', 'AbortError')) + }) + } + }) } /** diff --git a/packages/node/src/app/analytics-node.ts b/packages/node/src/app/analytics-node.ts index 16ddd6e2d..4cebe51da 100644 --- a/packages/node/src/app/analytics-node.ts +++ b/packages/node/src/app/analytics-node.ts @@ -56,7 +56,7 @@ export class Analytics extends NodeEmitter implements CoreAnalytics { typeof settings.httpClient === 'function' ? new FetchHTTPClient(settings.httpClient) : settings.httpClient ?? new FetchHTTPClient(), - oauthSettings: settings.oauthSettings, + tokenManagerProps: settings.tokenManagerProps, }, this as NodeEmitter ) @@ -73,14 +73,6 @@ export class Analytics extends NodeEmitter implements CoreAnalytics { return version } - get oauthSettings() { - return this._publisher.oauthSettings - } - - set oauthSettings(value) { - this._publisher.oauthSettings = value - } - /** * Call this method to stop collecting new events and flush all existing events. * This method also waits for any event method-specific callbacks to be triggered, diff --git a/packages/node/src/app/settings.ts b/packages/node/src/app/settings.ts index 14173a753..9fe91553b 100644 --- a/packages/node/src/app/settings.ts +++ b/packages/node/src/app/settings.ts @@ -1,6 +1,6 @@ import { ValidationError } from '@segment/analytics-core' import { HTTPClient, HTTPFetchFn } from '../lib/http-client' -import { OauthSettings } from '../lib/oauth-util' +import { TokenManagerProps } from '../lib/token-manager' export interface AnalyticsSettings { /** @@ -44,7 +44,7 @@ export interface AnalyticsSettings { /** * Set up OAuth2 authentication between the client and Segment's endpoints */ - oauthSettings?: OauthSettings + tokenManagerProps?: TokenManagerProps } export const validateSettings = (settings: AnalyticsSettings) => { diff --git a/packages/node/src/lib/__tests__/oauth-util.test.ts b/packages/node/src/lib/__tests__/token-manager.test.ts similarity index 51% rename from packages/node/src/lib/__tests__/oauth-util.test.ts rename to packages/node/src/lib/__tests__/token-manager.test.ts index 7f0c2ec41..f7a848866 100644 --- a/packages/node/src/lib/__tests__/oauth-util.test.ts +++ b/packages/node/src/lib/__tests__/token-manager.test.ts @@ -1,7 +1,7 @@ -import { RefreshToken, OauthData, OauthSettings } from '../oauth-util' +import { sleep } from '@segment/analytics-core' import { TestFetchClient } from '../../__tests__/test-helpers/create-test-analytics' -import { readFileSync } from 'fs' import { HTTPResponse } from '../http-client' +import { TokenManager, TokenManagerProps } from '../token-manager' const privateKey = Buffer.from(`-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVll7uJaH322IN @@ -53,68 +53,110 @@ const createOAuthError = (overrides: Partial = {}) => { }) as Promise } -const getOauthData = () => { - const oauthSettings = { +const getTokenManager = () => { + const tokenManagerProps = { + httpClient: testClient, + maxRetries: 3, clientId: 'clientId', clientKey: privateKey, keyId: 'keyId', scope: 'scope', authServer: 'http://127.0.0.1:1234', - } as OauthSettings + } as TokenManagerProps - const oauthData = { - httpClient: testClient, - settings: oauthSettings, - maxRetries: 3, - } as unknown as OauthData - return oauthData + return new TokenManager(tokenManagerProps) } -test('OAuth Success', async () => { - fetcher.mockReturnValueOnce( - createOAuthSuccess({ - access_token: 'token', - expires_in: '100', - }) - ) - - const oauthData = getOauthData() - - RefreshToken(oauthData) - await oauthData.refreshPromise - - expect(oauthData.refreshTimer).toBeDefined() - expect(oauthData.refreshPromise).toBeUndefined() - expect(oauthData.token).toBe('token') - expect(fetcher).toHaveBeenCalledTimes(1) -}) - -test('OAuth retry failure', async () => { - fetcher.mockReturnValue(createOAuthError({ status: 425 })) - - const oauthData = getOauthData() - - RefreshToken(oauthData) - await expect(oauthData.refreshPromise).rejects.toThrowError( - 'Retry limit reached - Foo' - ) - - expect(oauthData.refreshTimer).toBeUndefined() - expect(oauthData.refreshPromise).toBeUndefined() - expect(oauthData.token).toBeUndefined() - expect(fetcher).toHaveBeenCalledTimes(3) -}) - -test('OAuth immediate failure', async () => { - fetcher.mockReturnValue(createOAuthError({ status: 400 })) - - const oauthData = getOauthData() - - RefreshToken(oauthData) - await expect(oauthData.refreshPromise).rejects.toThrowError('Foo') - - expect(oauthData.refreshTimer).toBeUndefined() - expect(oauthData.refreshPromise).toBeUndefined() - expect(oauthData.token).toBeUndefined() - expect(fetcher).toHaveBeenCalledTimes(1) -}) +test( + 'OAuth Success', + async () => { + fetcher.mockReturnValueOnce( + createOAuthSuccess({ + access_token: 'token', + expires_in: 100, + }) + ) + + const tokenManager = getTokenManager() + const token = await tokenManager.getAccessToken() + tokenManager.stopPoller() + + expect(tokenManager.isValidToken(token)).toBeTruthy() + expect(token.access_token).toBe('token') + expect(token.expires_in).toBe(100) + expect(fetcher).toHaveBeenCalledTimes(1) + }, + 30 * 1000 +) + +test( + 'OAuth retry failure', + async () => { + fetcher.mockReturnValue(createOAuthError({ status: 425 })) + + const tokenManager = getTokenManager() + + await expect(tokenManager.getAccessToken()).rejects.toThrowError('Foo') + tokenManager.stopPoller() + + expect(fetcher).toHaveBeenCalledTimes(3) + }, + 30 * 1000 +) + +test( + 'OAuth immediate failure', + async () => { + fetcher.mockReturnValue(createOAuthError({ status: 400 })) + + const tokenManager = getTokenManager() + + await expect(tokenManager.getAccessToken()).rejects.toThrowError('Foo') + tokenManager.stopPoller() + + expect(fetcher).toHaveBeenCalledTimes(1) + }, + 30 * 1000 +) + +test( + 'OAuth rate limit', + async () => { + fetcher + .mockReturnValueOnce( + createOAuthError({ + status: 429, + headers: { 'X-RateLimit-Reset': Date.now() + 1000 }, + }) + ) + .mockReturnValueOnce( + createOAuthError({ + status: 429, + headers: { 'X-RateLimit-Reset': Date.now() + 1000 }, + }) + ) + .mockReturnValue( + createOAuthSuccess({ + access_token: 'token', + expires_in: 100, + }) + ) + + const tokenManager = getTokenManager() + + const tokenPromise = tokenManager.getAccessToken() + await sleep(250) + expect(fetcher).toHaveBeenCalledTimes(1) + await sleep(250) + expect(fetcher).toHaveBeenCalledTimes(2) + await sleep(350) + expect(fetcher).toHaveBeenCalledTimes(3) + + const token = await tokenPromise + expect(tokenManager.isValidToken(token)).toBeTruthy() + expect(token.access_token).toBe('token') + expect(token.expires_in).toBe(100) + expect(fetcher).toHaveBeenCalledTimes(3) + }, + 30 * 1000 +) diff --git a/packages/node/src/lib/oauth-util.ts b/packages/node/src/lib/oauth-util.ts index 1ea456a6e..c8aa1ae9d 100644 --- a/packages/node/src/lib/oauth-util.ts +++ b/packages/node/src/lib/oauth-util.ts @@ -1,153 +1,170 @@ -import { SignOptions, sign } from 'jsonwebtoken' -import { HTTPClient, HTTPClientRequest } from './http-client' -import { backoff } from '@segment/analytics-core' - -function sleep(timeoutInMs: number): Promise { - return new Promise((resolve) => setTimeout(resolve, timeoutInMs)) -} - -export interface OauthSettings { - clientId: string - clientKey: Buffer - keyId: string - scope?: string - authServer?: string - issuedAt?: number -} - -export interface OauthData { - httpClient: HTTPClient - settings: OauthSettings - token: string | undefined - refreshPromise: Promise | undefined - refreshTimer: ReturnType | undefined - maxRetries: number -} - -export const RefreshToken = (data: OauthData) => { - clearTimeout(data.refreshTimer) - data.refreshTimer = undefined - if (!data.refreshPromise) { - data.refreshPromise = RefreshTokenAsync(data) - } -} - -export const RefreshTokenAsync = async (data: OauthData) => { - const header = { - alg: 'RS256', - kid: data.settings.keyId, - 'Content-Type': 'application/x-www-form-urlencoded', - } as Record - const jti = Math.floor(Math.random() * 9999).toString() - - const body = { - iss: data.settings.clientId, - sub: data.settings.clientId, - aud: 'https://oauth2.segment.io', - iat: data.settings.issuedAt ?? Math.round(new Date().getTime() / 1000), - exp: Math.round(new Date().getTime() / 1000 + 60), - jti: jti, - } - - const options: SignOptions = { - algorithm: 'RS256', - keyid: data.settings.keyId, - } - - const signedJwt = sign(body, data.settings.clientKey, options) - const scope = data.settings.scope ?? 'tracking_api:write' - - const requestBody = - 'grant_type=client_credentials' + - '&client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer' + - '&client_assertion=' + - signedJwt + - '&scope=' + - scope - - const authServer = - data.settings.authServer ?? 'https://oauth2.segment.build/token' - - const requestOptions: HTTPClientRequest = { - method: 'POST', - url: authServer, - body: requestBody, - headers: header, - httpRequestTimeout: 10000, - } - - const maxAttempts = data.maxRetries - let currentAttempt = 0 - let lastError = '' - while (currentAttempt < maxAttempts) { - currentAttempt++ - try { - const response = await data.httpClient.makeRequest(requestOptions) - - if (response.status === 200) { - let access_token = '' - let expires_in = 0 - const result = await (response.json() as Promise<{ - access_token: string - expires_in: number - }>) - try { - access_token = result.access_token - expires_in = result.expires_in - } catch { - throw new Error('Malformed token response - ' + result) - } - data.refreshTimer = setTimeout( - RefreshToken, - (expires_in * 1000) / 2, - data - ) - data.refreshTimer.unref() - data.token = access_token - data.refreshPromise = undefined - return - } - - // We may be refreshing the token early and still have a valid token. - if ([400, 401, 415].includes(response.status)) { - // Unrecoverable errors - throw new Error(response.statusText) - } else if (response.status == 429) { - // Rate limit, wait until reset timestamp - const rateLimitResetTime = response.headers['X-RateLimit-Reset'] - let rateLimitDiff = 60 - if (rateLimitResetTime) { - rateLimitDiff = - parseInt(rateLimitResetTime) - - Math.round(new Date().getTime() / 1000) + - 5 - } - data.refreshTimer = setTimeout(RefreshToken, rateLimitDiff, data) - data.refreshTimer.unref() - data.refreshPromise = undefined - return - } - - lastError = response.statusText - - // Retry after attempt-based backoff. - await sleep( - backoff({ - attempt: currentAttempt, - minTimeout: 25, - maxTimeout: 1000, - }) - ) - } catch (err) { - clearTimeout(data.refreshTimer) - data.refreshTimer = undefined - data.refreshPromise = undefined - throw err - } - } - // Out of retries - clearTimeout(data.refreshTimer) - data.refreshTimer = undefined - data.refreshPromise = undefined - throw new Error('Retry limit reached - ' + lastError) -} +// import { SignOptions, sign } from 'jsonwebtoken' +// import { FetchHTTPClient, HTTPClient, HTTPClientRequest } from './http-client' +// import { backoff } from '@segment/analytics-core' + +// function sleep(timeoutInMs: number): Promise { +// return new Promise((resolve) => setTimeout(resolve, timeoutInMs)) +// } + +// export interface OauthSettings { +// clientId: string +// clientKey: Buffer +// keyId: string +// scope?: string +// authServer?: string +// issuedAt?: number +// httpClient?: HTTPClient +// maxRetries?: number +// } + +// export class OauthManager { +// clientId: string +// clientKey: Buffer +// keyId: string +// scope: string +// authServer: string +// issuedAt?: number +// httpClient: HTTPClient +// maxRetries: number +// token: string | undefined +// refreshPromise: Promise | undefined +// refreshTimer: ReturnType | undefined + +// constructor(settings: OauthSettings) { +// this.clientId = settings.clientId +// this.clientKey = settings.clientKey +// this.keyId = settings.keyId +// this.scope = settings.scope ?? 'tracking_api:write' +// this.authServer = settings.authServer ?? 'https://oauth2.segment.build/token' +// this.issuedAt = settings.issuedAt +// this.httpClient = settings.httpClient ?? new FetchHTTPClient() +// this.maxRetries = settings.maxRetries ?? 3 +// } + +// RefreshToken = () => { +// clearTimeout(data.refreshTimer) +// data.refreshTimer = undefined +// if (!data.refreshPromise) { +// data.refreshPromise = RefreshTokenAsync(data) +// } +// } + +// RequestToken = async () => { +// const header = { +// alg: 'RS256', +// kid: this.keyId, +// 'Content-Type': 'application/x-www-form-urlencoded', +// } as Record +// const jti = Math.floor(Math.random() * 9999).toString() + +// const body = { +// iss: this.clientId, +// sub: this.clientId, +// aud: 'https://oauth2.segment.io', +// iat: this.issuedAt ?? Math.round(new Date().getTime() / 1000), +// exp: Math.round(new Date().getTime() / 1000 + 60), +// jti: jti, +// } + +// const options: SignOptions = { +// algorithm: 'RS256', +// keyid: this.keyId, +// } + +// const signedJwt = sign(body, this.clientKey, options) +// const scope = this.scope ?? 'tracking_api:write' + +// const requestBody = +// 'grant_type=client_credentials' + +// '&client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer' + +// '&client_assertion=' + +// signedJwt + +// '&scope=' + +// scope + +// const authServer = this.authServer + +// const requestOptions: HTTPClientRequest = { +// method: 'POST', +// url: authServer, +// body: requestBody, +// headers: header, +// httpRequestTimeout: 10000, +// } + +// const maxAttempts = this.maxRetries +// let currentAttempt = 0 +// let lastError = '' +// while (currentAttempt < maxAttempts) { +// currentAttempt++ +// try { +// const response = await this.httpClient.makeRequest(requestOptions) + +// if (response.status === 200) { +// let access_token = '' +// let expires_in = 0 +// const result = await (response.json() as Promise<{ +// access_token: string +// expires_in: number +// }>) +// try { +// access_token = result.access_token +// expires_in = result.expires_in +// } catch { +// throw new Error('Malformed token response - ' + result) +// } +// this.refreshTimer = setTimeout( +// RefreshToken, +// (expires_in * 1000) / 2, +// this +// ) +// this.refreshTimer.unref() +// this.token = access_token +// this.refreshPromise = undefined +// return +// } + +// // We may be refreshing the token early and still have a valid token. +// if ([400, 401, 415].includes(response.status)) { +// // Unrecoverable errors +// throw new Error(response.statusText) +// } else if (response.status == 429) { +// // Rate limit, wait until reset timestamp +// const rateLimitResetTime = response.headers['X-RateLimit-Reset'] +// let rateLimitDiff = 60 +// if (rateLimitResetTime) { +// rateLimitDiff = +// parseInt(rateLimitResetTime) - +// Math.round(new Date().getTime() / 1000) + +// 5 +// } +// data.refreshTimer = setTimeout(RefreshToken, rateLimitDiff, data) +// data.refreshTimer.unref() +// data.refreshPromise = undefined +// return +// } + +// lastError = response.statusText + +// // Retry after attempt-based backoff. +// await sleep( +// backoff({ +// attempt: currentAttempt, +// minTimeout: 25, +// maxTimeout: 1000, +// }) +// ) +// } catch (err) { +// clearTimeout(data.refreshTimer) +// data.refreshTimer = undefined +// data.refreshPromise = undefined +// throw err +// } +// } +// // Out of retries +// clearTimeout(data.refreshTimer) +// data.refreshTimer = undefined +// data.refreshPromise = undefined +// throw new Error('Retry limit reached - ' + lastError) +// } +// } diff --git a/packages/node/src/lib/token-manager.ts b/packages/node/src/lib/token-manager.ts new file mode 100644 index 000000000..b1909f52a --- /dev/null +++ b/packages/node/src/lib/token-manager.ts @@ -0,0 +1,261 @@ +import { uuid } from './uuid' +import { + FetchHTTPClient, + HTTPClient, + HTTPClientRequest, + HTTPResponse, +} from './http-client' +import { SignOptions, sign } from 'jsonwebtoken' +import { Emitter, backoff, sleep } from '@segment/analytics-core' + +type AccessToken = { + access_token: string + expires_in: number +} + +export type TokenManagerProps = { + httpClient: HTTPClient | undefined + maxRetries: number + authServer: string + scope: string + clientId: string + clientKey: Buffer + keyId: string +} + +export class TokenManager { + private alg = 'RS256' as const + private grantType = 'client_credentials' as const + private clientAssertionType = + 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' as const + private clientId: string + private clientKey: Buffer + private keyId: string + private scope: string + private authServer: string + private httpClient: HTTPClient + private maxRetries: number + private clockSkewInSeconds = 0 + + private controller: AbortController + private signal: AbortSignal + + private accessToken?: AccessToken + private isRunning = false + private tokenEmitter = new Emitter<{ + access_token: [{ token: AccessToken } | { error: unknown }] + }>() + + constructor(props: TokenManagerProps) { + this.keyId = props.keyId + this.clientId = props.clientId + this.clientKey = props.clientKey + this.authServer = props.authServer + this.scope = props.scope + this.httpClient = props.httpClient ?? new FetchHTTPClient() + this.maxRetries = props.maxRetries + this.tokenEmitter.on('access_token', (event) => { + if ('token' in event) { + this.accessToken = event.token + } + }) + this.controller = new AbortController() + this.signal = this.controller.signal + } + + async startPoller() { + if (this.isRunning) return + this.isRunning = true + + let retryCount = 0 + let lastError: any + + while (this.isRunning) { + let timeUntilRefreshInMs = 0 + let response: HTTPResponse + + try { + response = await this.requestAccessToken() + } catch (err) { + // Error without a status code - likely networking, retry (backoff or immediately?) + retryCount++ + lastError = err + await sleep( + backoff({ + attempt: retryCount, + minTimeout: 25, + maxTimeout: 1000, + }), + this.signal + ) + continue + } + + // TODO: Calculate clock skew using reponse.headers.Date compared to system time + + // Handle status codes! + if (response.status === 200) { + let body: any + try { + body = await response.json() // TODO: Replace with actual method to get body - needs discussion since different HTTP clients expose this differently (buffers, streams, strings, objects) + } catch (err) { + // Errors reading the body (not parsing) are likely networking issues, we can retry + retryCount++ + lastError = err + //console.log(lastError) + continue + } + let token: AccessToken + try { + const parsedBody = /*JSON.parse(*/ body /*)*/ + // TODO: validate JSON + token = parsedBody + + this.tokenEmitter.emit('access_token', { token }) + + // Reset our failure count + retryCount = 0 + + // Refresh the token after half the expiry time passes + timeUntilRefreshInMs = Math.floor((token.expires_in / 2) * 1000) + } catch (err) { + // Something went really wrong with the body, lets surface an error and try again? + this.tokenEmitter.emit('access_token', { error: err }) + retryCount = 0 + //console.log(err) + + timeUntilRefreshInMs = backoff({ + attempt: retryCount, + minTimeout: 25, + maxTimeout: 1000, + }) + } + } else if (response.status === 429) { + retryCount++ + lastError = response.statusText + //console.log(lastError) + const rateLimitResetTime = parseInt( + response.headers['X-RateLimit-Reset'], + 10 + ) + if (isFinite(rateLimitResetTime)) { + timeUntilRefreshInMs = + (rateLimitResetTime - Date.now()) / 2 + + this.clockSkewInSeconds * 1000 + } else { + timeUntilRefreshInMs = 60 * 1000 + } + } else if ([400, 401, 415].includes(response.status)) { + // Unrecoverable errors + retryCount = 0 + this.tokenEmitter.emit('access_token', { + error: new Error(response.statusText), + }) + //console.log(response.statusText) + this.stopPoller() + return + } else { + retryCount++ + lastError = new Error(response.statusText) + //console.log(lastError) + timeUntilRefreshInMs = backoff({ + attempt: retryCount, + minTimeout: 25, + maxTimeout: 1000, + }) + } + + if (retryCount >= this.maxRetries) { + this.tokenEmitter.emit('access_token', { error: lastError }) + //console.log(lastError) + // TODO: figure out timing and whether to reset retries? + } + //console.log('Sleeping for: ' + timeUntilRefreshInMs) + await sleep(timeUntilRefreshInMs, this.signal) + } + } + + stopPoller() { + // TODO: Use abort controller to end the while loop in startPoller() + if (this.isRunning) { + //console.log('Abork bork bork!') + this.controller.abort() + } + this.isRunning = false + } + + /** + * Solely responsible for building the HTTP request and calling the token service. + */ + private requestAccessToken(): Promise { + const jti = uuid() + const currentUTCInSeconds = Math.round(Date.now() / 1000) + const jwtBody = { + iss: this.clientId, + sub: this.clientId, + aud: 'https://oauth2.segment.io', + iat: currentUTCInSeconds, + exp: currentUTCInSeconds + 60, + jti, + } + + const signingOptions: SignOptions = { + algorithm: this.alg, + keyid: this.keyId, + } + + const signedJwt = sign(jwtBody, this.clientKey, signingOptions) + + const requestBody = `grant_type=${this.grantType}&client_assertion_type=${this.clientAssertionType}&client_assertion=${signedJwt}&scope=${this.scope}` + const accessTokenEndpoint = `${this.authServer}/token` + + const requestOptions: HTTPClientRequest = { + method: 'POST', + url: accessTokenEndpoint, + body: requestBody, + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + httpRequestTimeout: 10000, + } + + //console.log('!!!!!!!!! fetch') + return this.httpClient.makeRequest(requestOptions) + } + + async getAccessToken(): Promise { + //console.log('############################## Access token requested') + // Use the cached token if it is still valid, otherwise wait for a new token. + if (this.isValidToken(this.accessToken)) { + return this.accessToken + } + + // stop poller first in order to make sure that it's not sleeping if we need a token immediately + // Otherwise it could be hours before the expiration time passes normally + this.stopPoller() + + // startPoller needs to be called somewhere, either lazily when a token is first requested, or at instantiation. + // Doing it lazily for this example + this.startPoller().catch(() => {}) + + return new Promise((resolve, reject) => { + this.tokenEmitter.once('access_token', (event) => { + if ('token' in event) { + resolve(event.token) + } else { + reject(event.error) + } + }) + }) + } + + clearToken() { + this.accessToken = undefined + } + + isValidToken(token?: AccessToken): token is AccessToken { + // TODO: Check if it has already expired? + // otherwise this check is pretty much useless + return typeof token !== 'undefined' && token !== null + } +} diff --git a/packages/node/src/plugins/segmentio/publisher.ts b/packages/node/src/plugins/segmentio/publisher.ts index 4cb80b703..92f6d3214 100644 --- a/packages/node/src/plugins/segmentio/publisher.ts +++ b/packages/node/src/plugins/segmentio/publisher.ts @@ -5,7 +5,8 @@ import { extractPromiseParts } from '../../lib/extract-promise-parts' import { ContextBatch } from './context-batch' import { NodeEmitter } from '../../app/emitter' import { HTTPClient, HTTPClientRequest } from '../../lib/http-client' -import { RefreshToken, OauthSettings, OauthData } from '../../lib/oauth-util' +//import { RefreshToken, OauthSettings, OauthData } from '../../lib/oauth-util' +import { TokenManager, TokenManagerProps } from '../../lib/token-manager' function sleep(timeoutInMs: number): Promise { return new Promise((resolve) => setTimeout(resolve, timeoutInMs)) @@ -28,7 +29,7 @@ export interface PublisherProps { httpRequestTimeout?: number disable?: boolean httpClient: HTTPClient - oauthSettings?: OauthSettings + tokenManagerProps?: TokenManagerProps } /** @@ -48,7 +49,7 @@ export class Publisher { private _disable: boolean private _httpClient: HTTPClient private _writeKey: string - private _oauthData: OauthData | undefined + private _tokenManager: TokenManager | undefined constructor( { host, @@ -60,7 +61,7 @@ export class Publisher { httpRequestTimeout, httpClient, disable, - oauthSettings, + tokenManagerProps, }: PublisherProps, emitter: NodeEmitter ) { @@ -77,31 +78,9 @@ export class Publisher { this._httpClient = httpClient this._writeKey = writeKey - if (oauthSettings != null) { - this.oauthSettings = oauthSettings - } - } - - get oauthSettings() { - return this._oauthData?.settings - } - - set oauthSettings(value) { - if (value) { - if (this._oauthData) { - this._oauthData.settings = value - } else { - this._oauthData = { - httpClient: this._httpClient, - settings: value, - maxRetries: this._maxRetries, - } as OauthData - } - RefreshToken(this._oauthData) - } else { - throw new Error( - "OAuth settings can't be removed, create a new analytics object instead." - ) + if (tokenManagerProps != null) { + tokenManagerProps.httpClient ??= httpClient + this._tokenManager = new TokenManager(tokenManagerProps) } } @@ -229,18 +208,10 @@ export class Publisher { } let authString = undefined - if (this._oauthData?.settings) { - if (!this._oauthData.token) { - if (!this._oauthData.refreshPromise) { - RefreshToken(this._oauthData) - } - await this._oauthData.refreshPromise - if (!this._oauthData.token) { - // If we don't have a token then authorization failed after multiple attempts - continue // ??? - } - } - authString = `Bearer ${this._oauthData.token}` + if (this._tokenManager) { + authString = `Bearer ${ + (await this._tokenManager.getAccessToken()).access_token + }` } let headers @@ -279,14 +250,13 @@ export class Publisher { batch.resolveEvents() return } else if ( - this._oauthData && - this._oauthData.settings && + this._tokenManager && (response.status === 400 || response.status === 401 || response.status === 403) ) { // Retry with a new OAuth token if we have OAuth data - RefreshToken(this._oauthData) + this._tokenManager.clearToken() failureReason = new Error( `[${response.status}] ${response.statusText}` )