diff --git a/packages/amplify-data-construct/.jsii b/packages/amplify-data-construct/.jsii index 2db7108892..328098c1c2 100644 --- a/packages/amplify-data-construct/.jsii +++ b/packages/amplify-data-construct/.jsii @@ -6,7 +6,7 @@ ] }, "bundled": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "^0.2.0", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.2", @@ -4027,5 +4027,5 @@ }, "types": {}, "version": "1.10.2", - "fingerprint": "naLvPjdr9z9wuDIK2e15PtMBvodNvsprfPj2tyr4FaA=" + "fingerprint": "jyWXx9H104F/Z3YgGPhDKqRig51p9ImclAMN35IID8o=" } \ No newline at end of file diff --git a/packages/amplify-data-construct/package.json b/packages/amplify-data-construct/package.json index cdc8c53c16..e1a646b78f 100644 --- a/packages/amplify-data-construct/package.json +++ b/packages/amplify-data-construct/package.json @@ -157,7 +157,7 @@ "semver" ], "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "^0.2.0", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-api-construct": "1.14.0", diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/API.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/API.ts index ed262bb185..2ce7d3e80c 100644 --- a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/API.ts +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/API.ts @@ -5,7 +5,6 @@ export type ConversationMessagePirateChat = { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array | null; content?: Array | null; conversation?: ConversationPirateChat | null; conversationId: string; @@ -285,7 +284,6 @@ export type ModelConversationMessagePirateChatConditionInput = { export type CreateConversationMessagePirateChatInput = { aiContext?: string | null; - assistantContent?: Array | null; content?: Array | null; conversationId: string; id?: string | null; @@ -386,10 +384,6 @@ export type GetConversationMessagePirateChatQuery = { getConversationMessagePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; @@ -488,10 +482,6 @@ export type CreateAssistantResponsePirateChatMutation = { createAssistantResponsePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; @@ -526,10 +516,6 @@ export type CreateConversationMessagePirateChatMutation = { createConversationMessagePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; @@ -585,10 +571,6 @@ export type DeleteConversationMessagePirateChatMutation = { deleteConversationMessagePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; @@ -659,10 +641,6 @@ export type PirateChatMutation = { __typename: 'ToolConfiguration'; } | null; updatedAt?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; conversation?: { __typename: 'ConversationPirateChat'; createdAt: string; @@ -683,10 +661,6 @@ export type OnCreateAssistantResponsePirateChatSubscription = { onCreateAssistantResponsePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; @@ -721,10 +695,6 @@ export type OnCreateConversationMessagePirateChatSubscription = { onCreateConversationMessagePirateChat?: { __typename: 'ConversationMessagePirateChat'; aiContext?: string | null; - assistantContent?: Array<{ - __typename: 'ContentBlock'; - text?: string | null; - } | null> | null; content?: Array<{ __typename: 'ContentBlock'; text?: string | null; diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/mutations.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/mutations.ts index 084cacf909..0370ddbb53 100644 --- a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/mutations.ts +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/mutations.ts @@ -13,10 +13,6 @@ export const createAssistantResponsePirateChat = /* GraphQL */ `mutation CreateA ) { createAssistantResponsePirateChat(input: $input) { aiContext - assistantContent { - text - __typename - } content { text __typename @@ -49,10 +45,6 @@ export const createConversationMessagePirateChat = /* GraphQL */ `mutation Creat ) { createConversationMessagePirateChat(condition: $condition, input: $input) { aiContext - assistantContent { - text - __typename - } content { text __typename @@ -104,10 +96,6 @@ export const deleteConversationMessagePirateChat = /* GraphQL */ `mutation Delet ) { deleteConversationMessagePirateChat(condition: $condition, input: $input) { aiContext - assistantContent { - text - __typename - } content { text __typename @@ -181,10 +169,6 @@ export const pirateChat = /* GraphQL */ `mutation PirateChat( updatedAt ... on ConversationMessagePirateChat { - assistantContent { - text - __typename - } conversation { createdAt id diff --git a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/queries.ts b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/queries.ts index 59ddfd0114..3f63a5d645 100644 --- a/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/queries.ts +++ b/packages/amplify-graphql-api-construct-tests/src/__tests__/conversations/graphql/queries.ts @@ -11,10 +11,6 @@ type GeneratedQuery = string & { export const getConversationMessagePirateChat = /* GraphQL */ `query GetConversationMessagePirateChat($id: ID!) { getConversationMessagePirateChat(id: $id) { aiContext - assistantContent { - text - __typename - } content { text __typename diff --git a/packages/amplify-graphql-api-construct/.jsii b/packages/amplify-graphql-api-construct/.jsii index e8c89a0f8a..d00875fe9b 100644 --- a/packages/amplify-graphql-api-construct/.jsii +++ b/packages/amplify-graphql-api-construct/.jsii @@ -6,7 +6,7 @@ ] }, "bundled": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "^0.2.0", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.2", @@ -8959,5 +8959,5 @@ } }, "version": "1.14.0", - "fingerprint": "u4jHwzwBKRAlcky7o7NnOthEdgdtKR1CtWBsGA5tXq4=" + "fingerprint": "BFQM638XQPsnf3jPECgT+5lULFJd0VyEXzkdPElCSAA=" } \ No newline at end of file diff --git a/packages/amplify-graphql-api-construct/package.json b/packages/amplify-graphql-api-construct/package.json index e8b51e2717..236e443518 100644 --- a/packages/amplify-graphql-api-construct/package.json +++ b/packages/amplify-graphql-api-construct/package.json @@ -158,7 +158,7 @@ "semver" ], "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "^0.2.0", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.2", diff --git a/packages/amplify-graphql-conversation-transformer/package.json b/packages/amplify-graphql-conversation-transformer/package.json index 29c03d1359..44821cb42f 100644 --- a/packages/amplify-graphql-conversation-transformer/package.json +++ b/packages/amplify-graphql-conversation-transformer/package.json @@ -24,7 +24,7 @@ "extract-api": "ts-node ../../scripts/extract-api.ts" }, "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "^0.2.0", "@aws-amplify/graphql-directives": "2.2.0", "@aws-amplify/graphql-index-transformer": "3.0.4", "@aws-amplify/graphql-model-transformer": "3.0.4", diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap index 3747205910..fc11a5bf4e 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap @@ -87,11 +87,12 @@ export function response(ctx) { exports[`ConversationTransformer valid schemas should transform conversation route with inference configuration 3`] = ` "export function request(ctx) { const { authFilter } = ctx.stash; + const { conversationId } = ctx.args; const query = { expression: 'id = :id', expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId, + ':id': conversationId, }), }; @@ -126,7 +127,7 @@ export function request(ctx) { const args = ctx.stash.transformedArgs ?? ctx.args; const defaultValues = ctx.stash.defaultValues ?? {}; const message = { - __typename: 'ConversationMessagepirateChat', + __typename: 'ConversationMessagePirateChat', role: 'user', ...args, ...defaultValues, @@ -150,13 +151,12 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "import { util } from '@aws-appsync/utils'; export function request(ctx) { - const { args, request, prev } = ctx; + const { args, request } = ctx; const { graphqlApiEndpoint } = ctx.stash; const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const messages = prev.result.items; const responseMutation = { name: 'createAssistantResponsePirateChat', inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', @@ -176,6 +176,14 @@ export function request(ctx) { clientTools }; + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + const authHeader = request.headers['authorization']; const payload = { conversationId: args.conversationId, @@ -184,7 +192,7 @@ export function request(ctx) { graphqlApiEndpoint, modelConfiguration, request: { headers: { authorization: authHeader } }, - messages, + messageHistoryQuery, toolsConfiguration, }; @@ -205,6 +213,8 @@ export function response(ctx) { conversationId: ctx.args.conversationId, role: 'user', content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, createdAt: ctx.stash.defaultValues.createdAt, updatedAt: ctx.stash.defaultValues.updatedAt, }; @@ -300,11 +310,12 @@ export function response(ctx) { exports[`ConversationTransformer valid schemas should transform conversation route with model query tool 3`] = ` "export function request(ctx) { const { authFilter } = ctx.stash; + const { conversationId } = ctx.args; const query = { expression: 'id = :id', expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId, + ':id': conversationId, }), }; @@ -339,7 +350,7 @@ export function request(ctx) { const args = ctx.stash.transformedArgs ?? ctx.args; const defaultValues = ctx.stash.defaultValues ?? {}; const message = { - __typename: 'ConversationMessagepirateChat', + __typename: 'ConversationMessagePirateChat', role: 'user', ...args, ...defaultValues, @@ -363,13 +374,12 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "import { util } from '@aws-appsync/utils'; export function request(ctx) { - const { args, request, prev } = ctx; + const { args, request } = ctx; const { graphqlApiEndpoint } = ctx.stash; const toolDefinitions = {"tools":[{"name":"listTodos","description":"lists todos","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { content isDone id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listTodos"}}]}; const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const messages = prev.result.items; const responseMutation = { name: 'createAssistantResponsePirateChat', inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', @@ -391,6 +401,14 @@ export function request(ctx) { clientTools, }; + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + const authHeader = request.headers['authorization']; const payload = { conversationId: args.conversationId, @@ -399,7 +417,7 @@ export function request(ctx) { graphqlApiEndpoint, modelConfiguration, request: { headers: { authorization: authHeader } }, - messages, + messageHistoryQuery, toolsConfiguration, }; @@ -420,6 +438,8 @@ export function response(ctx) { conversationId: ctx.args.conversationId, role: 'user', content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, createdAt: ctx.stash.defaultValues.createdAt, updatedAt: ctx.stash.defaultValues.updatedAt, }; @@ -515,11 +535,12 @@ export function response(ctx) { exports[`ConversationTransformer valid schemas should transform conversation route with model query tool including relationships 3`] = ` "export function request(ctx) { const { authFilter } = ctx.stash; + const { conversationId } = ctx.args; const query = { expression: 'id = :id', expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId, + ':id': conversationId, }), }; @@ -554,7 +575,7 @@ export function request(ctx) { const args = ctx.stash.transformedArgs ?? ctx.args; const defaultValues = ctx.stash.defaultValues ?? {}; const message = { - __typename: 'ConversationMessagepirateChat', + __typename: 'ConversationMessagePirateChat', role: 'user', ...args, ...defaultValues, @@ -578,13 +599,12 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "import { util } from '@aws-appsync/utils'; export function request(ctx) { - const { args, request, prev } = ctx; + const { args, request } = ctx; const { graphqlApiEndpoint } = ctx.stash; const toolDefinitions = {"tools":[{"name":"listCustomers","description":"Provides data about the customer sending a message","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { name email activeCart { products { name price } customerId id createdAt updatedAt owner } orderHistory { items { products { name price } customerId id createdAt updatedAt owner } nextToken } id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listCustomers"}}]}; const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const messages = prev.result.items; const responseMutation = { name: 'createAssistantResponsePirateChat', inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', @@ -606,6 +626,14 @@ export function request(ctx) { clientTools, }; + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + const authHeader = request.headers['authorization']; const payload = { conversationId: args.conversationId, @@ -614,7 +642,7 @@ export function request(ctx) { graphqlApiEndpoint, modelConfiguration, request: { headers: { authorization: authHeader } }, - messages, + messageHistoryQuery, toolsConfiguration, }; @@ -635,6 +663,8 @@ export function response(ctx) { conversationId: ctx.args.conversationId, role: 'user', content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, createdAt: ctx.stash.defaultValues.createdAt, updatedAt: ctx.stash.defaultValues.updatedAt, }; @@ -730,11 +760,12 @@ export function response(ctx) { exports[`ConversationTransformer valid schemas should transform conversation route with query tools 3`] = ` "export function request(ctx) { const { authFilter } = ctx.stash; + const { conversationId } = ctx.args; const query = { expression: 'id = :id', expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId, + ':id': conversationId, }), }; @@ -769,7 +800,7 @@ export function request(ctx) { const args = ctx.stash.transformedArgs ?? ctx.args; const defaultValues = ctx.stash.defaultValues ?? {}; const message = { - __typename: 'ConversationMessagepirateChat', + __typename: 'ConversationMessagePirateChat', role: 'user', ...args, ...defaultValues, @@ -793,13 +824,12 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "import { util } from '@aws-appsync/utils'; export function request(ctx) { - const { args, request, prev } = ctx; + const { args, request } = ctx; const { graphqlApiEndpoint } = ctx.stash; const toolDefinitions = {"tools":[{"name":"getTemperature","description":"does a thing","inputSchema":{"json":{"type":"object","properties":{"city":{"type":"string","description":"A UTF-8 character sequence."}},"required":["city"]}},"graphqlRequestInputDescriptor":{"selectionSet":"value unit","propertyTypes":{"city":"String!"},"queryName":"getTemperature"}},{"name":"plus","description":"does a different thing","inputSchema":{"json":{"type":"object","properties":{"a":{"type":"number","description":"A signed 32-bit integer value."},"b":{"type":"number","description":"A signed 32-bit integer value."}},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"","propertyTypes":{"a":"Int","b":"Int"},"queryName":"plus"}}]}; const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const messages = prev.result.items; const responseMutation = { name: 'createAssistantResponsePirateChat', inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', @@ -821,6 +851,14 @@ export function request(ctx) { clientTools, }; + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + const authHeader = request.headers['authorization']; const payload = { conversationId: args.conversationId, @@ -829,7 +867,7 @@ export function request(ctx) { graphqlApiEndpoint, modelConfiguration, request: { headers: { authorization: authHeader } }, - messages, + messageHistoryQuery, toolsConfiguration, }; @@ -850,6 +888,8 @@ export function response(ctx) { conversationId: ctx.args.conversationId, role: 'user', content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, createdAt: ctx.stash.defaultValues.createdAt, updatedAt: ctx.stash.defaultValues.updatedAt, }; diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts b/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts index 3ac99b9890..87d6a26c37 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts @@ -43,6 +43,11 @@ describe('ConversationTransformer', () => { const schema = parse(out.schema); validateModelSchema(schema); + + expect( + out.stacks.ConversationMessagePirateChat.Resources![`ListConversationMessage${toUpper(routeName)}Resolver`].Properties + .PipelineConfig.Functions, + ).toHaveLength(4); }); }); diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/schemas/conversation-schema-types.graphql b/packages/amplify-graphql-conversation-transformer/src/__tests__/schemas/conversation-schema-types.graphql index fac17f04e5..0bddee354a 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/schemas/conversation-schema-types.graphql +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/schemas/conversation-schema-types.graphql @@ -10,6 +10,7 @@ interface ConversationMessage { content: [ContentBlock] context: AWSJSON toolConfiguration: ToolConfiguration + associatedUserMessageId: ID } input DocumentBlockSourceInput { diff --git a/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts b/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts index 75950edd72..12eab5bf46 100644 --- a/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts +++ b/packages/amplify-graphql-conversation-transformer/src/graphql-types/message-model.ts @@ -227,12 +227,12 @@ const constructConversationMessageModel = ( const content = makeField('content', [], makeListType(makeNamedType('ContentBlock'))); const context = makeField('aiContext', [], makeNamedType('AWSJSON')); const uiComponents = makeField('toolConfiguration', [], makeNamedType('ToolConfiguration')); - const assistantContent = makeField('assistantContent', [], makeListType(makeNamedType('ContentBlock'))); + const associatedUserMessageId = makeField('associatedUserMessageId', [], makeNamedType('ID')); const object = { ...blankObject(modelName), interfaces: [conversationMessageInterface], - fields: [id, conversationId, sessionField, role, content, context, uiComponents, assistantContent], + fields: [id, conversationId, sessionField, role, content, context, uiComponents, associatedUserMessageId], directives: typeDirectives, }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js index d35ff74551..288b483083 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js @@ -1,4 +1,5 @@ import { util } from '@aws-appsync/utils'; +import * as ddb from '@aws-appsync/utils/dynamodb'; /** * Sends a request to the attached data source @@ -6,30 +7,23 @@ import { util } from '@aws-appsync/utils'; * @returns {*} the request */ export function request(ctx) { - const owner = ctx.identity['claims']['sub']; - ctx.stash.owner = owner; const { conversationId, content, associatedUserMessageId } = ctx.args.input; - const updatedAt = util.time.nowISO8601(); + const { owner } = ctx.args; + const defaultValues = ctx.stash.defaultValues ?? {}; + const id = defaultValues.id; - const expression = 'SET #assistantContent = :assistantContent, #updatedAt = :updatedAt'; - const expressionNames = { '#assistantContent': 'assistantContent', '#updatedAt': 'updatedAt' }; - const expressionValues = { ':assistantContent': content, ':updatedAt': updatedAt }; - const condition = JSON.parse( - util.transform.toDynamoDBConditionExpression({ - owner: { eq: owner }, - conversationId: { eq: conversationId }, - }), - ); - return { - operation: 'UpdateItem', - key: util.dynamodb.toMapValues({ id: associatedUserMessageId }), - condition, - update: { - expression, - expressionNames, - expressionValues: util.dynamodb.toMapValues(expressionValues), - }, + const message = { + __typename: '[[CONVERSATION_MESSAGE_TYPE_NAME]]', + id, + role: 'assistant', + content, + conversationId, + associatedUserMessageId, + owner, + ...defaultValues, }; + + return ddb.put({ key: { id }, item: message }); } /** @@ -43,16 +37,5 @@ export function response(ctx) { util.error(ctx.error.message, ctx.error.type); } - const { conversationId, content, associatedUserMessageId } = ctx.args.input; - const { createdAt, updatedAt } = ctx.result; - - return { - id: associatedUserMessageId, - content, - conversationId, - role: 'assistant', - owner: ctx.stash.owner, - createdAt, - updatedAt, - }; + return ctx.result; } diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts index ebfcf45d5e..0df9f8b473 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts @@ -3,6 +3,7 @@ import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interf import fs from 'fs'; import path from 'path'; import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; +import { toUpper } from 'graphql-transformer-common'; /** * Creates and returns the mapping template for the assistant mutation resolver. @@ -11,7 +12,15 @@ import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-tran * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ export const assistantMutationResolver = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { - const resolver = fs.readFileSync(path.join(__dirname, 'assistant-mutation-resolver-fn.template.js'), 'utf8'); + let resolver = fs.readFileSync(path.join(__dirname, 'assistant-mutation-resolver-fn.template.js'), 'utf8'); + const fieldName = toUpper(config.field.name.value); + const substitutions = { + CONVERSATION_MESSAGE_TYPE_NAME: `ConversationMessage${fieldName}`, + }; + Object.entries(substitutions).forEach(([key, value]) => { + const replaced = resolver.replace(new RegExp(`\\[\\[${key}\\]\\]`, 'g'), value); + resolver = replaced; + }); const templateName = `Mutation.${config.field.name.value}.assistant-response.js`; return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js index 00c8c6e19c..557bacbaf5 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js @@ -1,13 +1,12 @@ import { util } from '@aws-appsync/utils'; export function request(ctx) { - const { args, request, prev } = ctx; + const { args, request } = ctx; const { graphqlApiEndpoint } = ctx.stash; [[TOOL_DEFINITIONS_LINE]] const selectionSet = '[[SELECTION_SET]]'; - const messages = prev.result.items; const responseMutation = { name: '[[RESPONSE_MUTATION_NAME]]', inputTypeName: '[[RESPONSE_MUTATION_INPUT_TYPE_NAME]]', @@ -21,6 +20,14 @@ export function request(ctx) { }); [[TOOLS_CONFIGURATION_LINE]] + const messageHistoryQuery = { + getQueryName: '[[GET_QUERY_NAME]]', + getQueryInputTypeName: '[[GET_QUERY_INPUT_TYPE_NAME]]', + listQueryName: '[[LIST_QUERY_NAME]]', + listQueryInputTypeName: '[[LIST_QUERY_INPUT_TYPE_NAME]]', + listQueryLimit: [[LIST_QUERY_LIMIT]], + }; + const authHeader = request.headers['authorization']; const payload = { conversationId: args.conversationId, @@ -29,7 +36,7 @@ export function request(ctx) { graphqlApiEndpoint, modelConfiguration, request: { headers: { authorization: authHeader } }, - messages, + messageHistoryQuery, toolsConfiguration, }; @@ -50,6 +57,8 @@ export function response(ctx) { conversationId: ctx.args.conversationId, role: 'user', content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, createdAt: ctx.stash.defaultValues.createdAt, updatedAt: ctx.stash.defaultValues.updatedAt, }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts index 0fd3e9d16d..b73c3b4325 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts @@ -4,6 +4,8 @@ import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-tran import fs from 'fs'; import path from 'path'; import dedent from 'ts-dedent'; +import { toUpper } from 'graphql-transformer-common'; +import pluralize from 'pluralize'; /** * Creates a mapping template for invoking a Lambda function in the context of a GraphQL conversation. @@ -20,6 +22,14 @@ export const invokeLambdaMappingTemplate = (config: ConversationDirectiveConfigu const RESPONSE_MUTATION_INPUT_TYPE_NAME = config.responseMutationInputTypeName; const MESSAGE_MODEL_NAME = config.messageModel.messageModel.name.value; + // TODO: Create and add these values to `ConversationDirectiveConfiguration` in an earlier step and + // access them here. + const GET_QUERY_NAME = `getConversationMessage${toUpper(config.field.name.value)}`; + const GET_QUERY_INPUT_TYPE_NAME = 'ID'; + const LIST_QUERY_NAME = `listConversationMessage${toUpper(pluralize(config.field.name.value))}`; + const LIST_QUERY_INPUT_TYPE_NAME = `ModelConversationMessage${toUpper(config.field.name.value)}FilterInput`; + const LIST_QUERY_LIMIT = 'undefined'; + const substitutions = { TOOL_DEFINITIONS_LINE, TOOLS_CONFIGURATION_LINE, @@ -28,6 +38,11 @@ export const invokeLambdaMappingTemplate = (config: ConversationDirectiveConfigu RESPONSE_MUTATION_NAME, RESPONSE_MUTATION_INPUT_TYPE_NAME, MESSAGE_MODEL_NAME, + GET_QUERY_NAME, + GET_QUERY_INPUT_TYPE_NAME, + LIST_QUERY_NAME, + LIST_QUERY_INPUT_TYPE_NAME, + LIST_QUERY_LIMIT, }; let resolver = fs.readFileSync(path.join(__dirname, 'invoke-lambda-resolver-fn.template.js'), 'utf8'); diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.template.js new file mode 100644 index 0000000000..71a6777a2f --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.template.js @@ -0,0 +1,8 @@ +export function request(ctx) { + ctx.stash.metadata.index = 'gsi-ConversationMessage.conversationId.createdAt'; + return {}; +} + +export function response(ctx) { + return {}; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts new file mode 100644 index 0000000000..de1b608354 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts @@ -0,0 +1,16 @@ +import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; +import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; + +/** + * Creates and returns the function code for the list messages resolver init slot. + * + * @returns {MappingTemplateProvider} + */ +export const listMessageInitMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'list-messages-init-resolver-fn.template.js'), 'utf8'); + const templateName = `Query.${config.field.name.value}.list-messages-init.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); +}; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js deleted file mode 100644 index 82e7666ded..0000000000 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js +++ /dev/null @@ -1,40 +0,0 @@ -export function request(ctx) { - const { conversationId } = ctx.args; - const { authFilter } = ctx.stash; - - const limit = 100; - const query = { - expression: 'conversationId = :conversationId', - expressionValues: util.dynamodb.toMapValues({ - ':conversationId': conversationId, - }), - }; - - const filter = JSON.parse(util.transform.toDynamoDBFilterExpression(authFilter)); - const index = 'gsi-ConversationMessage.conversationId.createdAt'; - - return { - operation: 'Query', - query, - filter, - index, - scanIndexForward: false, - }; -} - -export function response(ctx) { - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } - const messagesWithAssistantResponse = ctx.result.items - .filter((message) => message.assistantContent !== undefined) - .reduce((acc, current) => { - acc.push({ role: 'user', content: current.content }); - acc.push({ role: 'assistant', content: current.assistantContent }); - return acc; - }, []); - - const currentMessage = { role: 'user', content: ctx.prev.result.content }; - const items = [...messagesWithAssistantResponse, currentMessage]; - return { items }; -} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts deleted file mode 100644 index ce56cfa558..0000000000 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; -import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import fs from 'fs'; -import path from 'path'; -import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; - -/** - * Creates a mapping template for reading message history in a conversation. - * - * @returns {MappingTemplateProvider} An object containing request and response mapping functions. - */ -export const readHistoryMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { - const resolver = fs.readFileSync(path.join(__dirname, 'message-history-resolver-fn.template.js'), 'utf8'); - const templateName = `Mutation.${config.field.name.value}.message-history.js`; - return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); -}; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js index 7a05ff10b7..25ce987425 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js @@ -1,10 +1,11 @@ export function request(ctx) { const { authFilter } = ctx.stash; + const { conversationId } = [[CONVERSATION_ID_PARENT]]; const query = { expression: 'id = :id', expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId, + ':id': conversationId, }), }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts index 0624e1c9d5..a718c580d8 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts @@ -9,8 +9,32 @@ import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-tran * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const verifySessionOwnerMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { - const resolver = fs.readFileSync(path.join(__dirname, 'verify-session-owner-resolver-fn.template.js'), 'utf8'); +export const verifySessionOwnerSendMessageMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const substitutions = { + CONVERSATION_ID_PARENT: 'ctx.args', + }; const templateName = `Mutation.${config.field.name.value}.verify-session-owner.js`; - return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); + return verifySessionOwnerMappingTemplate(templateName, substitutions); +}; + +/** + * Creates a mapping template for verifying the session owner in a conversation. + * + * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. + */ +export const verifySessionOwnerAssistantResponseMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const substitutions = { + CONVERSATION_ID_PARENT: 'ctx.args.input', + }; + const templateName = `Mutation.${config.field.name.value}AssistantResponse.verify-session-owner.js`; + return verifySessionOwnerMappingTemplate(templateName, substitutions); +}; + +const verifySessionOwnerMappingTemplate = (name: string, substitute: Record) => { + let resolver = fs.readFileSync(path.join(__dirname, 'verify-session-owner-resolver-fn.template.js'), 'utf8'); + Object.entries(substitute).forEach(([key, value]) => { + const replaced = resolver.replace(new RegExp(`\\[\\[${key}\\]\\]`, 'g'), value); + resolver = replaced; + }); + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, name); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts index cf9848dc01..c7fc24b484 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts @@ -2,15 +2,18 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; import fs from 'fs'; import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; +import { toUpper } from 'graphql-transformer-common'; /** * Creates a mapping template for writing a message to a table in a conversation. * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const writeMessageToTableMappingTemplate = (fieldName: string): MappingTemplateProvider => { +export const writeMessageToTableMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const fieldName = config.field.name.value; const substitutions = { - CONVERSATION_MESSAGE_TYPE_NAME: `ConversationMessage${fieldName}`, + CONVERSATION_MESSAGE_TYPE_NAME: `ConversationMessage${toUpper(fieldName)}`, }; let resolver = fs.readFileSync(path.join(__dirname, 'write-message-to-table-resolver-fn.template.js'), 'utf8'); Object.entries(substitutions).forEach(([key, value]) => { diff --git a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts index 1d272523c3..844595dd26 100644 --- a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts +++ b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts @@ -9,13 +9,17 @@ import { IFunction, Function } from 'aws-cdk-lib/aws-lambda'; import { getModelDataSourceNameForTypeName, getTable } from '@aws-amplify/graphql-transformer-core'; import { initMappingTemplate } from '../resolvers/init-resolver'; import { authMappingTemplate } from '../resolvers/auth-resolver'; -import { verifySessionOwnerMappingTemplate } from '../resolvers/verify-session-owner-resolver'; +import { + verifySessionOwnerSendMessageMappingTemplate, + verifySessionOwnerAssistantResponseMappingTemplate, +} from '../resolvers/verify-session-owner-resolver'; import { writeMessageToTableMappingTemplate } from '../resolvers/write-message-to-table-resolver'; -import { readHistoryMappingTemplate } from '../resolvers/message-history-resolver'; import { invokeLambdaMappingTemplate } from '../resolvers/invoke-lambda-resolver'; import { assistantMutationResolver } from '../resolvers/assistant-mutation-resolver'; import { conversationMessageSubscriptionMappingTamplate } from '../resolvers/assistant-messages-subscription-resolver'; import { overrideIndexAtCfnLevel } from '@aws-amplify/graphql-index-transformer'; +import pluralize from 'pluralize'; +import { listMessageInitMappingTemplate } from '../resolvers/list-messages-init-resolver'; type KeyAttributeDefinition = { name: string; @@ -28,6 +32,7 @@ export class ConversationResolverGenerator { for (const directive of directives) { this.processToolsForDirective(directive, ctx); this.generateResolversForDirective(directive, ctx); + this.addInitSlotToListMessagesPipeline(ctx, directive); } } @@ -46,14 +51,15 @@ export class ConversationResolverGenerator { const functionStack = this.createFunctionStack(ctx, capitalizedFieldName); const { functionDataSourceId, referencedFunction } = this.setupFunctionDataSource(directive, functionStack, capitalizedFieldName); - - this.createAssistantResponseResolver(ctx, directive, capitalizedFieldName); - this.createAssistantResponseSubscriptionResolver(ctx, directive, capitalizedFieldName); - const functionDataSource = this.addLambdaDataSource(ctx, functionDataSourceId, referencedFunction, capitalizedFieldName); const invokeLambdaFunction = invokeLambdaMappingTemplate(directive); this.setupMessageTableIndex(ctx, directive); + const initResolverFunction = initMappingTemplate(ctx); + const authResolverFunction = authMappingTemplate(directive); + const verifySessionOwnerSendMessageResolverFunction = verifySessionOwnerSendMessageMappingTemplate(directive); + const verifySessionOwnerAssistantResponseResolverFunction = verifySessionOwnerAssistantResponseMappingTemplate(directive); + const writeMessageToTableFunction = writeMessageToTableMappingTemplate(directive); this.createConversationPipelineResolver( ctx, @@ -62,8 +68,21 @@ export class ConversationResolverGenerator { capitalizedFieldName, functionDataSource, invokeLambdaFunction, + initResolverFunction, + authResolverFunction, + verifySessionOwnerSendMessageResolverFunction, + writeMessageToTableFunction, + ); + + this.createAssistantResponseResolver( + ctx, directive, + capitalizedFieldName, + initResolverFunction, + authResolverFunction, + verifySessionOwnerAssistantResponseResolverFunction, ); + this.createAssistantResponseSubscriptionResolver(ctx, directive, capitalizedFieldName); } /** @@ -169,7 +188,10 @@ export class ConversationResolverGenerator { capitalizedFieldName: string, functionDataSource: any, invokeLambdaFunction: MappingTemplateProvider, - directive: ConversationDirectiveConfiguration, + initResolverFunction: MappingTemplateProvider, + authResolverFunction: MappingTemplateProvider, + verifySessionOwnerResolverFunction: MappingTemplateProvider, + writeMessageToTableFunction: MappingTemplateProvider, ): void { const resolverResourceId = ResolverResourceIDs.ResolverResourceID(parentName, fieldName); const runtime = APPSYNC_JS_RUNTIME; @@ -178,13 +200,21 @@ export class ConversationResolverGenerator { fieldName, resolverResourceId, { codeMappingTemplate: invokeLambdaFunction }, - ['init', 'auth', 'verifySessionOwner', 'writeMessageToTable', 'retrieveMessageHistory'], + ['init', 'auth', 'verifySessionOwner', 'writeMessageToTable'], ['handleLambdaResponse', 'finish'], functionDataSource, runtime, ); - this.addPipelineResolverFunctions(ctx, conversationPipelineResolver, capitalizedFieldName, directive); + this.addPipelineResolverFunctions( + ctx, + conversationPipelineResolver, + capitalizedFieldName, + initResolverFunction, + authResolverFunction, + verifySessionOwnerResolverFunction, + writeMessageToTableFunction, + ); ctx.resolvers.addResolver(parentName, fieldName, conversationPipelineResolver); } @@ -200,33 +230,28 @@ export class ConversationResolverGenerator { ctx: TransformerContextProvider, resolver: TransformerResolver, capitalizedFieldName: string, - directive: ConversationDirectiveConfiguration, + initResolverFunction: MappingTemplateProvider, + authResolverFunction: MappingTemplateProvider, + verifySessionOwnerResolverFunction: MappingTemplateProvider, + writeMessageToTableFunction: MappingTemplateProvider, ): void { // Add init function - const initFunction = initMappingTemplate(ctx); - resolver.addJsFunctionToSlot('init', initFunction); + resolver.addJsFunctionToSlot('init', initResolverFunction); // Add auth function - const authFunction = authMappingTemplate(directive); - resolver.addJsFunctionToSlot('auth', authFunction); + resolver.addJsFunctionToSlot('auth', authResolverFunction); // Add verifySessionOwner function - const verifySessionOwnerFunction = verifySessionOwnerMappingTemplate(directive); const sessionModelName = `Conversation${capitalizedFieldName}`; const sessionModelDDBDataSourceName = getModelDataSourceNameForTypeName(ctx, sessionModelName); const conversationSessionDDBDataSource = ctx.api.host.getDataSource(sessionModelDDBDataSourceName); - resolver.addJsFunctionToSlot('verifySessionOwner', verifySessionOwnerFunction, conversationSessionDDBDataSource as any); + resolver.addJsFunctionToSlot('verifySessionOwner', verifySessionOwnerResolverFunction, conversationSessionDDBDataSource as any); // Add writeMessageToTable function - const writeMessageToTableFunction = writeMessageToTableMappingTemplate(directive.field.name.value); const messageModelName = `ConversationMessage${capitalizedFieldName}`; const messageModelDDBDataSourceName = getModelDataSourceNameForTypeName(ctx, messageModelName); const messageDDBDataSource = ctx.api.host.getDataSource(messageModelDDBDataSourceName); resolver.addJsFunctionToSlot('writeMessageToTable', writeMessageToTableFunction, messageDDBDataSource as any); - - // Add retrieveMessageHistory function - const retrieveMessageHistoryFunction = readHistoryMappingTemplate(directive); - resolver.addJsFunctionToSlot('retrieveMessageHistory', retrieveMessageHistoryFunction, messageDDBDataSource as any); } /** @@ -239,23 +264,38 @@ export class ConversationResolverGenerator { ctx: TransformerContextProvider, directive: ConversationDirectiveConfiguration, capitalizedFieldName: string, + initResolverFunction: MappingTemplateProvider, + authResolverFunction: MappingTemplateProvider, + verifySessionOwnerResolverFunction: MappingTemplateProvider, ): void { const assistantResponseResolverResourceId = ResolverResourceIDs.ResolverResourceID('Mutation', directive.responseMutationName); const assistantResponseResolverFunction = assistantMutationResolver(directive); const conversationMessageDataSourceName = getModelDataSourceNameForTypeName(ctx, `ConversationMessage${capitalizedFieldName}`); const conversationMessageDataSource = ctx.api.host.getDataSource(conversationMessageDataSourceName); - const assistantResponseResolver = new TransformerResolver( + const resolver = new TransformerResolver( 'Mutation', directive.responseMutationName, assistantResponseResolverResourceId, { codeMappingTemplate: assistantResponseResolverFunction }, - [], + ['init', 'auth', 'verifySessionOwner'], [], conversationMessageDataSource as any, APPSYNC_JS_RUNTIME, ); - ctx.resolvers.addResolver('Mutation', directive.responseMutationName, assistantResponseResolver); + // Add init function + resolver.addJsFunctionToSlot('init', initResolverFunction); + + // Add auth function + resolver.addJsFunctionToSlot('auth', authResolverFunction); + + // Add verifySessionOwner function + const sessionModelName = `Conversation${capitalizedFieldName}`; + const sessionModelDDBDataSourceName = getModelDataSourceNameForTypeName(ctx, sessionModelName); + const conversationSessionDDBDataSource = ctx.api.host.getDataSource(sessionModelDDBDataSourceName); + resolver.addJsFunctionToSlot('verifySessionOwner', verifySessionOwnerResolverFunction, conversationSessionDDBDataSource as any); + + ctx.resolvers.addResolver('Mutation', directive.responseMutationName, resolver); } /** @@ -313,6 +353,14 @@ export class ConversationResolverGenerator { return ctx.api.host.addLambdaDataSource(functionDataSourceId, referencedFunction, {}, functionDataSourceScope); } + private addInitSlotToListMessagesPipeline(ctx: TransformerContextProvider, directive: ConversationDirectiveConfiguration): void { + const messageModelName = directive.messageModel.messageModel.name.value; + const pluralized = pluralize(messageModelName); + const listMessagesResolver = ctx.resolvers.getResolver('Query', `list${pluralized}`) as TransformerResolver; + const initResolverFn = listMessageInitMappingTemplate(directive); + listMessagesResolver.addJsFunctionToSlot('init', initResolverFn); + } + /** * Sets up the message table index * @param ctx - The transformer context provider diff --git a/yarn.lock b/yarn.lock index a610c02ed5..029a17465b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10,10 +10,10 @@ "@jridgewell/gen-mapping" "^0.3.5" "@jridgewell/trace-mapping" "^0.3.24" -"@aws-amplify/ai-constructs@^0.1.4": - version "0.1.4" - resolved "https://registry.npmjs.org/@aws-amplify/ai-constructs/-/ai-constructs-0.1.4.tgz#043ca7793cb4a97ad7864797bd70dbfa323329f4" - integrity sha512-BGLBFs/pt6JrNgUo+QD0Szt/ssHMa6EyEE45yLoHemwPHRuJPpnFmxIbbxgxaqJP0mWK6QMs9Wh3IsdJ/6XhDA== +"@aws-amplify/ai-constructs@^0.2.0": + version "0.2.0" + resolved "https://registry.yarnpkg.com/@aws-amplify/ai-constructs/-/ai-constructs-0.2.0.tgz#91db9586d8e656a4ad7f2b0b539a2221a38124b2" + integrity sha512-aqmUrUvbWpebJcNCvoFywHLTQXNIlli8VE2i9+sSMlQXAG2zRiqcpDdRha+0NQnPNj09K2/DMLTe79ldwDaGkQ== dependencies: "@aws-amplify/plugin-types" "^1.0.1" "@aws-sdk/client-bedrock-runtime" "^3.622.0"