From 6bd7eaa3d00e4c9acaf67e6e2278f2ad29a0c065 Mon Sep 17 00:00:00 2001 From: Deyaaeldeen Almahallawi Date: Thu, 15 Aug 2024 23:34:10 +0000 Subject: [PATCH] [Azure] Refresh AAD token on retry --- src/index.ts | 2 +- tests/lib/azure.test.ts | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/index.ts b/src/index.ts index 5f7dffd67..37cd6dbf2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -485,7 +485,7 @@ export class AzureOpenAI extends OpenAI { } protected override async prepareOptions(opts: Core.FinalRequestOptions): Promise { - if (opts.headers?.['Authorization'] || opts.headers?.['api-key']) { + if (opts.headers?.['api-key']) { return super.prepareOptions(opts); } const token = await this._getAzureADToken(); diff --git a/tests/lib/azure.test.ts b/tests/lib/azure.test.ts index 06ca1d464..1fd7782f2 100644 --- a/tests/lib/azure.test.ts +++ b/tests/lib/azure.test.ts @@ -254,6 +254,43 @@ describe('instantiate azure client', () => { /The `apiKey` and `azureADTokenProvider` arguments are mutually exclusive; only one can be passed at a time./, ); }); + + test('AAD token is refreshed', async () => { + let fail = true; + const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise => { + if (fail) { + fail = false; + return new Response(undefined, { + status: 429, + headers: { + 'Retry-After': '0.1', + }, + }); + } + return new Response( + JSON.stringify({ auth: (req?.headers as Record)['authorization'] }), + { headers: { 'content-type': 'application/json' } }, + ); + }; + let counter = 0; + async function azureADTokenProvider() { + return `token-${counter++}`; + } + const client = new AzureOpenAI({ + baseURL: 'http://localhost:5000/', + azureADTokenProvider, + apiVersion, + fetch: testFetch, + }); + expect( + await client.chat.completions.create({ + model, + messages: [{ role: 'system', content: 'Hello' }], + }), + ).toStrictEqual({ + auth: 'Bearer token-1', + }); + }); }); test('with endpoint', () => {