Skip to content

Commit

Permalink
feat(api-graphql): AppSync realtime - include custom headers as query…
Browse files Browse the repository at this point in the history
… string params (#13735)
  • Loading branch information
iartemiev authored Aug 22, 2024
1 parent 5224dc2 commit 5647497
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 3 deletions.
56 changes: 56 additions & 0 deletions packages/api-graphql/__tests__/GraphQLAPI.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1558,4 +1558,60 @@ describe('API test', () => {
);
});
});

test('request level custom headers are applied to query string', async () => {
Amplify.configure({
API: {
GraphQL: {
defaultAuthMode: 'lambda',
endpoint:
'https://testaccounturl123456789123.appsync-api.us-east-1.amazonaws.com/graphql',
region: 'local-host-h4x',
},
},
});

let done: Function;
const mockedFnHasBeenCalled = new Promise(res => (done = res));

const spyon_appsync_realtime = jest
.spyOn(
AWSAppSyncRealTimeProvider.prototype as any,
'_initializeRetryableHandshake',
)
.mockImplementation(
jest.fn(() => {
done(); // resolve promise when called
}) as any,
);

const query = /* GraphQL */ `
subscription SubscribeToEventComments {
subscribeToEventComments {
eventId
}
}
`;

const resolvedUrl =
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBdXRob3JpemF0aW9uIjoiYWJjMTIzNDUiLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ==&payload=e30=&x-amz-user-agent=aws-amplify%2F6.4.0%20api%2F1%20framework%2F2&ex-machina=is%20a%20good%20movie';

(
client.graphql(
{ query },
{
'x-amz-user-agent': 'aws-amplify/6.4.0 api/1 framework/2',
'ex-machina': 'is a good movie',
// This should NOT get included in the querystring
Authorization: 'abc12345',
},
) as unknown as Observable<object>
).subscribe();

await mockedFnHasBeenCalled;

expect(spyon_appsync_realtime).toHaveBeenCalledTimes(1);
const subscribeOptions = spyon_appsync_realtime.mock.calls[0][0];
expect(subscribeOptions).toBe(resolvedUrl);
});
});
60 changes: 59 additions & 1 deletion packages/api-graphql/__tests__/internals/generateClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ import { Amplify, AmplifyClassV6 } from '@aws-amplify/core';
import { generateClient } from '../../src/internals';
import configFixture from '../fixtures/modeled/amplifyconfiguration';
import { Schema } from '../fixtures/modeled/schema';
import { from } from 'rxjs';
import { Observable, from } from 'rxjs';
import {
normalizePostGraphqlCalls,
expectSubWithHeaders,
expectSubWithHeadersFn,
expectSubWithlibraryConfigHeaders,
mockApiResponse,
} from '../utils/index';
import { AWSAppSyncRealTimeProvider } from '../../src/Providers/AWSAppSyncRealTimeProvider';

const serverManagedFields = {
id: 'some-id',
Expand Down Expand Up @@ -332,6 +333,30 @@ describe('generateClient', () => {
expect(normalizePostGraphqlCalls(spy)).toMatchSnapshot();
});

test('with custom client headers - graphql', async () => {
const headers = {
'client-header': 'should exist',
};

const client = generateClient<Schema>({
amplify: Amplify,
headers,
});

await client.graphql({
query: /* GraphQL */ `
query listPosts {
id
}
`,
});

const receivedArgs = normalizePostGraphqlCalls(spy)[0][1];
const receivedHeaders = receivedArgs.options.headers;

expect(receivedHeaders).toEqual(expect.objectContaining(headers));
});

test('with custom client header functions', async () => {
const client = generateClient<Schema>({
amplify: Amplify,
Expand Down Expand Up @@ -495,6 +520,39 @@ describe('generateClient', () => {
});
});

test('with client-level custom headers', done => {
const customHeaders = {
'subscription-header': 'should-exist',
};

const client = generateClient<Schema>({
amplify: Amplify,
headers: customHeaders,
});

const spy = jest.fn(() => from([graphqlMessage]));
(raw.GraphQLAPI as any).appSyncRealTime = { subscribe: spy };

client.models.Note.onCreate({
filter: graphqlVariables.filter,
}).subscribe({
next(value) {
expectSubWithHeaders(
spy,
'onCreateNote',
graphqlVariables,
customHeaders,
);
expect(value).toEqual(expect.objectContaining(noteToSend));
done();
},
error(error) {
expect(error).toBeUndefined();
done('bad news!');
},
});
});

test('with a custom header function', done => {
const customHeaders = {
'subscription-header-function': 'should-return-this-header',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,44 @@ export class AWSAppSyncRealTimeProvider {
}
}

/**
* Strips out `Authorization` header if present
*/
private _extractNonAuthHeaders(
headers?: AWSAppSyncRealTimeProviderOptions['additionalCustomHeaders'],
): Record<string, string> {
if (!headers) {
return {};
}

if ('Authorization' in headers) {
const { Authorization: _, ...nonAuthHeaders } = headers;

return nonAuthHeaders;
}

return headers;
}

/**
*
* @param headers - http headers
* @returns query string of uri-encoded parameters derived from custom headers
*/
private _queryStringFromCustomHeaders(
headers?: AWSAppSyncRealTimeProviderOptions['additionalCustomHeaders'],
): string {
const nonAuthHeaders = this._extractNonAuthHeaders(headers);

const queryParams: string[] = Object.entries(nonAuthHeaders).map(
([key, val]) => `${encodeURIComponent(key)}=${encodeURIComponent(val)}`,
);

const queryString = queryParams.join('&');

return queryString;
}

private _initializeWebSocketConnection({
appSyncGraphqlEndpoint,
authenticationType,
Expand Down Expand Up @@ -749,6 +787,10 @@ export class AWSAppSyncRealTimeProvider {

const payloadQs = base64Encoder.convert(payloadString);

const queryString = this._queryStringFromCustomHeaders(
additionalCustomHeaders,
);

let discoverableEndpoint = appSyncGraphqlEndpoint ?? '';

if (this.isCustomDomain(discoverableEndpoint)) {
Expand All @@ -766,7 +808,11 @@ export class AWSAppSyncRealTimeProvider {
.replace('https://', protocol)
.replace('http://', protocol);

const awsRealTimeUrl = `${discoverableEndpoint}?header=${headerQs}&payload=${payloadQs}`;
let awsRealTimeUrl = `${discoverableEndpoint}?header=${headerQs}&payload=${payloadQs}`;

if (queryString !== '') {
awsRealTimeUrl += `&${queryString}`;
}

await this._initializeRetryableHandshake(awsRealTimeUrl);

Expand Down
3 changes: 2 additions & 1 deletion packages/api-graphql/src/internals/v6.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export function graphql<
const internals = getInternals(this as any);
options.authMode = options.authMode || internals.authMode;
options.authToken = options.authToken || internals.authToken;
const headers = additionalHeaders || internals.headers;

/**
* The correctness of these typings depends on correct string branding or overrides.
Expand All @@ -116,7 +117,7 @@ export function graphql<
// TODO: move V6Client back into this package?
internals.amplify as any,
options,
additionalHeaders,
headers,
);

return result as any;
Expand Down

0 comments on commit 5647497

Please sign in to comment.