Skip to content

Commit

Permalink
feat(appsync-modelgen-plugin): add support for generation route defin…
Browse files Browse the repository at this point in the history
…itions
  • Loading branch information
atierian committed Aug 28, 2024
1 parent 1a5bf2d commit 3a359d3
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
},
"inputs": {
"$ref": "#/definitions/SchemaInputs"
},
"generations": {
"$ref": "#/definitions/SchemaQueries"
}
},
"required": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1976,6 +1976,75 @@ exports[`Custom queries/mutations/subscriptions & input type tests should genera
}"
`;

exports[`Generation Route Introspection Visitor Metadata snapshot should generate correct model intropection file validated by JSON schema 1`] = `
"{
\\"version\\": 1,
\\"models\\": {},
\\"enums\\": {},
\\"nonModels\\": {
\\"Recipe\\": {
\\"name\\": \\"Recipe\\",
\\"fields\\": {
\\"name\\": {
\\"name\\": \\"name\\",
\\"isArray\\": false,
\\"type\\": \\"String\\",
\\"isRequired\\": false,
\\"attributes\\": []
},
\\"ingredients\\": {
\\"name\\": \\"ingredients\\",
\\"isArray\\": true,
\\"type\\": \\"String\\",
\\"isRequired\\": false,
\\"attributes\\": [],
\\"isArrayNullable\\": true
},
\\"instructions\\": {
\\"name\\": \\"instructions\\",
\\"isArray\\": false,
\\"type\\": \\"String\\",
\\"isRequired\\": false,
\\"attributes\\": []
}
}
}
},
\\"generations\\": {
\\"generateRecipe\\": {
\\"name\\": \\"generateRecipe\\",
\\"isArray\\": false,
\\"type\\": {
\\"nonModel\\": \\"Recipe\\"
},
\\"isRequired\\": false,
\\"arguments\\": {
\\"description\\": {
\\"name\\": \\"description\\",
\\"isArray\\": false,
\\"type\\": \\"String\\",
\\"isRequired\\": false
}
}
},
\\"summarize\\": {
\\"name\\": \\"summarize\\",
\\"isArray\\": false,
\\"type\\": \\"String\\",
\\"isRequired\\": false,
\\"arguments\\": {
\\"text\\": {
\\"name\\": \\"text\\",
\\"isArray\\": false,
\\"type\\": \\"String\\",
\\"isRequired\\": false
}
}
}
}
}"
`;

exports[`Model Introspection Visitor Metadata snapshot should generate correct model intropection file validated by JSON schema 1`] = `
"{
\\"version\\": 1,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { buildSchema, GraphQLSchema, parse, visit } from 'graphql';
import { METADATA_SCALAR_MAP } from '../../scalars';
import { AppSyncDirectives, DefaultDirectives, V1Directives, DeprecatedDirective, Directive } from '@aws-amplify/graphql-directives';
import { AppSyncDirectives, DefaultDirectives, V1Directives, DeprecatedDirective, Directive, V2Directives } from '@aws-amplify/graphql-directives';
import { scalars } from '../../scalars/supported-scalars';
import { AppSyncModelIntrospectionVisitor } from '../../visitors/appsync-model-introspection-visitor';

Expand Down Expand Up @@ -806,3 +806,45 @@ describe('custom references', () => {
.toThrowError(`Error processing @belongsTo directive on SqlRelated.primary. @hasOne or @hasMany directive with references ["primaryId"] was not found in connected model SqlPrimary`);
});
});

describe('Generation Route Introspection Visitor', () => {
const schema = /* GraphQL */ `
type Recipe {
name: String
ingredients: [String]
instructions: String
}
type Query {
generateRecipe(description: String): Recipe
@generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "You are a recipe generator.")
summarize(text: String): String
@generation(aiModel: "anthropic.claude-3-haiku-20240307-v1:0", systemPrompt: "You are a text summarizer.")
}
`;

const generationDirective: Directive = {
name: 'generation',
definition: /* GraphQL */ `
directive @generation(
aiModel: String!
systemPrompt: String!
inferenceConfiguration: GenerationInferenceConfiguration
) on FIELD_DEFINITION
input GenerationInferenceConfiguration {
maxTokens: Int
temperature: Float
topP: Float
}
`,
defaults: {},
}
const visitor: AppSyncModelIntrospectionVisitor = getVisitor(schema, {}, [...V2Directives, generationDirective]);
describe('Metadata snapshot', () => {
it('should generate correct model intropection file validated by JSON schema', () => {
expect(visitor.generate()).toMatchSnapshot();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
mutations?: SchemaMutations;
subscriptions?: SchemaSubscriptions;
inputs?: SchemaInputs;
generations?: SchemaQueries;
};
/**
* Top-level Entities on a Schema
Expand Down
1 change: 1 addition & 0 deletions packages/appsync-modelgen-plugin/src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export const TransformerV2DirectiveName = {
INDEX: 'index',
DEFAULT: 'default',
SEARCHABLE: 'searchable',
GENERATION: 'generation',
};
export const DEFAULT_HASH_KEY_FIELD = 'id';
export const DEFAULT_CREATED_TIME = 'createdAt';
Expand Down
6 changes: 5 additions & 1 deletion packages/appsync-modelgen-plugin/src/utils/fieldUtils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CodeGenDirective, CodeGenField, CodeGenModel } from '../visitors/appsync-visitor';
import { CodeGenDirective, CodeGenField, CodeGenModel, CodeGenQuery } from '../visitors/appsync-visitor';
import { TransformerV2DirectiveName } from './constants';

export function addFieldToModel(model: CodeGenModel, field: CodeGenField): void {
Expand Down Expand Up @@ -40,4 +40,8 @@ export function getModelPrimaryKeyComponentFields(model: CodeGenModel): CodeGenF
};
}
return keyFields;
}

export function containsGenerationDirective(queryField: CodeGenQuery): boolean {
return queryField.directives.some((directive) => directive.name === TransformerV2DirectiveName.GENERATION);
}
2 changes: 1 addition & 1 deletion packages/appsync-modelgen-plugin/src/validate-cjs.js

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Argument, AssociationType, Field, Fields, FieldType, ModelAttribute, Mo
import { METADATA_SCALAR_MAP } from "../scalars";
import { CodeGenConnectionType } from "../utils/process-connections";
import { RawAppSyncModelConfig, ParsedAppSyncModelConfig, AppSyncModelVisitor, CodeGenEnum, CodeGenField, CodeGenModel, CodeGenPrimaryKeyType, CodeGenQuery, CodeGenSubscription, CodeGenMutation, CodeGenInputObject } from "./appsync-visitor";
import { containsGenerationDirective } from "../utils/fieldUtils";

const validateModelIntrospectionSchema = require('../validate-cjs');

Expand Down Expand Up @@ -66,11 +67,17 @@ export class AppSyncModelIntrospectionVisitor<
// Skip the field if the field type is union/interface
// TODO: Remove this skip once these types are supported for stakeholder usages
const fieldType = this.getType(queryObj.type) as any;
if (this.isUnionFieldType(fieldType) || this.isInterfaceFieldType(fieldType)) {
if (this.isUnionFieldType(fieldType) || this.isInterfaceFieldType(fieldType) || containsGenerationDirective(queryObj)) {
return acc;
}
return { ...acc, [queryObj.name]: this.generateGraphQLOperationMetadata<CodeGenQuery, SchemaQuery>(queryObj) };
}, {})
const generations = Object.values(this.queryMap).reduce((acc, queryObj: CodeGenQuery) => {
if (!containsGenerationDirective(queryObj)) {
return acc;
}
return { ...acc, [queryObj.name]: this.generateGenerationMetadata(queryObj) };
}, {});
const mutations = Object.values(this.mutationMap).reduce((acc, mutationObj: CodeGenMutation) => {
// Skip the field if the field type is union/interface
// TODO: Remove this skip once these types are supported for stakeholder usages
Expand All @@ -95,6 +102,9 @@ export class AppSyncModelIntrospectionVisitor<
if (Object.keys(queries).length > 0) {
result = { ...result, queries };
}
if (Object.keys(generations).length > 0) {
result = { ...result, generations };
}
if (Object.keys(mutations).length > 0) {
result = { ...result, mutations };
}
Expand Down Expand Up @@ -236,6 +246,10 @@ export class AppSyncModelIntrospectionVisitor<
return operationMeta as V;
}

private generateGenerationMetadata(generationObj: CodeGenQuery): SchemaQuery {
return this.generateGraphQLOperationMetadata<CodeGenQuery, SchemaQuery>(generationObj);
}

protected getType(gqlType: string): FieldType | InputFieldType | UnionFieldType | InterfaceFieldType {
// Todo: Handle unlisted scalars
if (gqlType in METADATA_SCALAR_MAP) {
Expand Down

0 comments on commit 3a359d3

Please sign in to comment.