From d9efa76a1650da35501b5a80803b3253faf92d7b Mon Sep 17 00:00:00 2001 From: Michael Hayes Date: Fri, 7 Mar 2025 17:12:44 -0800 Subject: [PATCH] implement input validation mapping logic --- packages/core/src/utils/index.ts | 4 +- .../plugin-add-graphql/src/schema-builder.ts | 4 +- .../src/drizzle-field-builder.ts | 4 +- .../plugin-prisma-utils/src/schema-builder.ts | 8 +- .../plugin-prisma/src/prisma-field-builder.ts | 2 +- packages/plugin-validation/src/index.ts | 439 +++++++++++++++--- .../tests/example/schema/index.ts | 31 +- 7 files changed, 404 insertions(+), 88 deletions(-) diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts index af54b56e6..b555061eb 100644 --- a/packages/core/src/utils/index.ts +++ b/packages/core/src/utils/index.ts @@ -154,8 +154,8 @@ export function unwrapInputListParam( */ export function completeValue( valOrPromise: PromiseLike | T, - onSuccess: (completedVal: T) => R, - onError?: (errVal: unknown) => R, + onSuccess: (completedVal: T) => PromiseLike | R, + onError?: (errVal: unknown) => PromiseLike | R, ): Promise | R { if (isThenable(valOrPromise)) { return Promise.resolve(valOrPromise).then(onSuccess, onError); diff --git a/packages/plugin-add-graphql/src/schema-builder.ts b/packages/plugin-add-graphql/src/schema-builder.ts index c29b45b8d..2a4d8fb14 100644 --- a/packages/plugin-add-graphql/src/schema-builder.ts +++ b/packages/plugin-add-graphql/src/schema-builder.ts @@ -131,7 +131,7 @@ proto.addGraphQLObject = function addGraphQLObject( continue; } - const args: Record> = {}; + const args: Record> = {}; for (const { name, ...arg } of field.args) { const input = resolveInputType(this, arg.type); @@ -203,7 +203,7 @@ proto.addGraphQLInterface = function addGraphQLInterface( continue; } - const args: Record> = {}; + const args: Record> = {}; for (const { name, ...arg } of field.args) { args[name] = t.arg({ diff --git a/packages/plugin-drizzle/src/drizzle-field-builder.ts b/packages/plugin-drizzle/src/drizzle-field-builder.ts index 5cd51c07e..6b5f18be7 100644 --- a/packages/plugin-drizzle/src/drizzle-field-builder.ts +++ b/packages/plugin-drizzle/src/drizzle-field-builder.ts @@ -452,13 +452,13 @@ export class DrizzleObjectFieldBuilder< >, ] > - ) { + ): FieldRef, 'DrizzleObject'> { const [name, options = {} as never] = args; const typeConfig = this.builder.configStore.getTypeConfig(this.typename, 'Object'); const usingSelect = !!typeConfig.extensions?.pothosDrizzleSelect; - return this.exposeField(name as never, { + return this.exposeField(name as never, { ...options, extensions: { ...options.extensions, diff --git a/packages/plugin-prisma-utils/src/schema-builder.ts b/packages/plugin-prisma-utils/src/schema-builder.ts index b1c2ca494..fed278187 100644 --- a/packages/plugin-prisma-utils/src/schema-builder.ts +++ b/packages/plugin-prisma-utils/src/schema-builder.ts @@ -665,8 +665,8 @@ schemaBuilder.prismaUpdateRelation = function prismaUpdateRelation< data: dataType, } = fieldOption as { name?: string; - where: InputFieldRef | InputRef; - data: InputFieldRef | InputRef; + where: InputFieldRef | InputRef; + data: InputFieldRef | InputRef; }; const nestedRef = this.inputType(nestedName, { @@ -687,8 +687,8 @@ schemaBuilder.prismaUpdateRelation = function prismaUpdateRelation< skipDuplicates, } = fieldOption as { name?: string; - data: InputFieldRef | InputRef; - skipDuplicates?: InputFieldRef | InputRef; + data: InputFieldRef | InputRef; + skipDuplicates?: InputFieldRef | InputRef; }; const nestedRef = this.inputType(nestedName, { diff --git a/packages/plugin-prisma/src/prisma-field-builder.ts b/packages/plugin-prisma/src/prisma-field-builder.ts index 9bffe5267..4577649c4 100644 --- a/packages/plugin-prisma/src/prisma-field-builder.ts +++ b/packages/plugin-prisma/src/prisma-field-builder.ts @@ -614,7 +614,7 @@ export class PrismaObjectFieldBuilder< }, ] > - ) { + ): FieldRef, 'PrismaObject'> { const [options = {} as never] = args; const typeConfig = this.builder.configStore.getTypeConfig(this.typename); diff --git a/packages/plugin-validation/src/index.ts b/packages/plugin-validation/src/index.ts index 2e9dbfaad..5d69420dc 100644 --- a/packages/plugin-validation/src/index.ts +++ b/packages/plugin-validation/src/index.ts @@ -9,8 +9,18 @@ import SchemaBuilder, { RootFieldBuilder, type PothosInputFieldConfig, type SchemaTypes, + mapInputFields, + type PothosOutputFieldConfig, + type InputFieldsMapping, + type InputFieldMapping, + completeValue, + isThenable, + unwrapInputFieldType, + type MaybePromise, + PothosValidationError, } from '@pothos/core'; import type { StandardSchemaV1 } from './standard-schema'; +import type { GraphQLFieldResolver } from 'graphql'; export * from './types'; @@ -23,97 +33,386 @@ const pluginName = 'validation'; return args as never; }; -(FieldRef.prototype as FieldRef).validate = function validate() { +(FieldRef.prototype as FieldRef).validate = function validate(schema) { + this.updateConfig((config) => { + const extensions = (config.extensions ?? {}) as { validationSchemas?: StandardSchemaV1[] }; + + return { + ...config, + extensions: { + ...extensions, + validationSchemas: extensions.validationSchemas + ? [schema, ...extensions.validationSchemas] + : [schema], + }, + }; + }); return this as never; }; -(InputFieldRef.prototype as InputFieldRef).validate = function validate() { +(InputFieldRef.prototype as InputFieldRef).validate = function validate( + schema, +) { + this.updateConfig((config) => { + const extensions = (config.extensions ?? {}) as { validationSchemas?: StandardSchemaV1[] }; + + return { + ...config, + extensions: { + ...extensions, + validationSchemas: extensions.validationSchemas + ? [schema, ...extensions.validationSchemas] + : [schema], + }, + }; + }); return this as never; }; -(ArgumentRef.prototype as ArgumentRef).validate = function validate() { +(ArgumentRef.prototype as ArgumentRef).validate = function validate(schema) { + this.updateConfig((config) => { + const extensions = (config.extensions ?? {}) as { validationSchemas?: StandardSchemaV1[] }; + + return { + ...config, + extensions: { + ...extensions, + validationSchemas: extensions.validationSchemas + ? [schema, ...extensions.validationSchemas] + : [schema], + }, + }; + }); return this as never; }; -(InputObjectRef.prototype as InputObjectRef).validate = function validate() { +(InputObjectRef.prototype as InputObjectRef).validate = function validate( + schema, +) { + this.updateConfig((config) => { + const extensions = (config.extensions ?? {}) as { validationSchemas?: StandardSchemaV1[] }; + + return { + ...config, + extensions: { + ...extensions, + validationSchemas: extensions.validationSchemas + ? [schema, ...extensions.validationSchemas] + : [schema], + }, + }; + }); return this as never; }; +export class InputValidationError extends PothosValidationError { + issues: readonly StandardSchemaV1.Issue[]; + + constructor(issues: readonly StandardSchemaV1.Issue[]) { + super(issues.map((issue) => issue.message).join('\n')); + + this.issues = issues; + } +} + export class PothosZodPlugin extends BasePlugin { override onInputFieldConfig( fieldConfig: PothosInputFieldConfig, ): PothosInputFieldConfig { + if (fieldConfig.pothosOptions.validate) { + const extensions = (fieldConfig.extensions ?? {}) as { + validationSchemas?: StandardSchemaV1[]; + }; + + return { + ...fieldConfig, + extensions: { + ...extensions, + validationSchemas: [ + ...(extensions.validationSchemas ?? []), + fieldConfig.pothosOptions.validate, + ], + }, + }; + } + return fieldConfig; } - // override wrapResolve( - // resolver: GraphQLFieldResolver, - // fieldConfig: PothosOutputFieldConfig, - // ): GraphQLFieldResolver { - // // Only used to check if validation is required - // const argMap = mapInputFields( - // fieldConfig.args, - // this.buildCache, - // (field) => field.extensions?.validator ?? null, - // ); - - // if (!argMap && !fieldConfig.pothosOptions.validate) { - // return resolver; - // } - - // const args: Record> = {}; - - // for (const [argName, arg] of Object.entries(fieldConfig.args)) { - // const validator = arg.extensions?.validator as zod.ZodType | undefined; - - // if (validator) { - // args[argName] = validator; - // } - // } - - // let validator: zod.ZodTypeAny = zod.object(args).passthrough(); - - // if (fieldConfig.pothosOptions.validate) { - // validator = refine(validator, fieldConfig.pothosOptions.validate as ValidationOptionUnion); - // } - - // const validationError = this.builder.options.zod?.validationError; - - // const validatorWithErrorHandling R extends any - // validationError && - // async function validate(value: unknown, ctx: object, info: GraphQLResolveInfo) { - // try { - // const result: unknown = await validator.parseAsync(value); - - // return result; - // } catch (error: unknown) { - // const errorOrMessage = validationError( - // error as zod.ZodError, - // value as Record, - // ctx, - // info, - // ); - - // if (typeof errorOrMessage === 'string') { - // throw new PothosValidationError(errorOrMessage); - // } - - // throw errorOrMessage; - // } - // }; - - // return async (parent, rawArgs, context, info) => - // resolver( - // parent, - // (await (validatorWithErrorHandling - // ? validatorWithErrorHandling(rawArgs, context, info) - // : validator.parseAsync(rawArgs))) as object, - // context, - // info, - // ); - // } + override onOutputFieldConfig( + fieldConfig: PothosOutputFieldConfig, + ): PothosOutputFieldConfig | null { + if (fieldConfig.pothosOptions.validate) { + const extensions = (fieldConfig.extensions ?? {}) as { + validationSchemas?: StandardSchemaV1[]; + }; + + return { + ...fieldConfig, + extensions: { + ...extensions, + validationSchemas: [ + ...(extensions.validationSchemas ?? []), + fieldConfig.pothosOptions.validate, + ], + }, + }; + } + + return fieldConfig; + } + + override wrapResolve( + resolver: GraphQLFieldResolver, + fieldConfig: PothosOutputFieldConfig, + ): GraphQLFieldResolver { + // Only used to check if validation is required + const argMappings = mapInputFields(fieldConfig.args, this.buildCache, (field) => { + const fieldSchemas = (field.extensions?.validationSchemas as StandardSchemaV1[]) ?? null; + const fieldTypeName = unwrapInputFieldType(field.type); + const typeSchemas = + (this.buildCache.getTypeConfig(fieldTypeName).extensions + ?.validationSchemas as StandardSchemaV1[]) ?? null; + + return fieldSchemas || typeSchemas + ? { + fieldSchemas, + typeSchemas, + } + : null; + }); + + const argsSchemas = fieldConfig.extensions?.validationSchemas as StandardSchemaV1[] | null; + + if (!argMappings && !argsSchemas) { + return resolver; + } + + const argValidator = createArgsValidator(argMappings, argsSchemas); + + return async (parent, rawArgs, context, info) => + completeValue(argValidator(rawArgs), (validated) => { + return resolver(parent, validated as object, context, info); + }); + } } SchemaBuilder.registerPlugin(pluginName, PothosZodPlugin); export default pluginName; + +function createArgsValidator( + argMappings: InputFieldsMapping< + Types, + { + typeSchemas: StandardSchemaV1[]; + fieldSchemas: StandardSchemaV1[]; + } + > | null, + argsSchemas: StandardSchemaV1[] | null, +) { + const argMapper = argMappings + ? createInputValueMapper(argMappings, (value, mappings, addIssues) => { + const { typeSchemas, fieldSchemas } = mappings.value!; + const mapped = typeSchemas + ? reduceMaybeAsync(typeSchemas, value, (val, schema) => + completeValue(schema['~standard'].validate(val), (result) => { + if (result.issues) { + addIssues(result.issues); + return null; + } + + return result.value; + }), + ) + : value; + + if (mapped === null) { + return value; + } + + if (fieldSchemas) { + return reduceMaybeAsync(fieldSchemas, mapped, (val, schema) => + completeValue(schema['~standard'].validate(val), (result) => { + if (result.issues) { + addIssues(result.issues); + return null; + } + + return result.value; + }), + ); + } + + return mapped; + }) + : null; + + return function validateArgs(args: object) { + return completeValue( + argMapper ? argMapper(args) : { value: args, issues: undefined }, + (mapped) => { + if (mapped.issues) { + throw new InputValidationError(mapped.issues); + } + + if (!argsSchemas) { + return mapped.value; + } + + const issues: StandardSchemaV1.Issue[] = []; + + const validated = reduceMaybeAsync(argsSchemas, mapped.value, (val, schema) => + completeValue(schema['~standard'].validate(val), (result) => { + if (result.issues) { + issues.push(...result.issues); + return null; + } + + return result.value; + }), + ); + + return completeValue(validated, (result) => { + if (issues.length) { + throw new InputValidationError(issues); + } + + return result; + }); + }, + ); + }; +} + +function createInputValueMapper( + argMap: InputFieldsMapping, + mapValue: ( + val: unknown, + mapping: InputFieldMapping, + addIssues: (issues: readonly StandardSchemaV1.Issue[]) => void, + ...args: Args + ) => unknown, +) { + return function mapObject( + obj: object, + map: InputFieldsMapping = argMap, + path: (string | number)[] = [], + ...args: Args + ): MaybePromise> { + const mapped: Record = { ...obj }; + const issues: StandardSchemaV1.Issue[] = []; + + function addIssues(path: (string | number)[]) { + return (newIssues: readonly StandardSchemaV1.Issue[]) => { + issues.push( + ...newIssues.map((issue) => ({ ...issue, path: [...path, ...(issue.path ?? [])] })), + ); + }; + } + + const promises: Promise[] = []; + + map.forEach((field, fieldName) => { + const fieldVal = (obj as Record)[fieldName]; + const fieldPromises: Promise[] = []; + + if (fieldVal === null || fieldVal === undefined) { + mapped[fieldName] = fieldVal; + return; + } + + if (field.kind === 'InputObject' && field.fields.map) { + if (field.isList) { + const newList = [...(fieldVal as unknown[])]; + mapped[fieldName] = newList; + + (fieldVal as (Record | null)[]).map((val, i) => { + if (val) { + const promise = completeValue( + mapObject(val, field.fields.map!, [...path, fieldName, i], ...args), + (newVal) => { + if (newVal.issues) { + issues.push(...newVal.issues); + } else { + newList[i] = newVal.value; + } + }, + ); + + if (isThenable(promise)) { + fieldPromises.push(promise); + } + } + }); + } else { + const promise = completeValue( + mapObject( + fieldVal as Record, + field.fields.map, + [...path, fieldName], + ...args, + ), + (newVal) => { + if (newVal.issues) { + issues.push(...newVal.issues); + } else { + mapped[fieldName] = newVal.value; + } + }, + ); + + if (isThenable(promise)) { + fieldPromises.push(promise); + } + } + } + + const promise = completeValue( + fieldPromises.length ? Promise.all(fieldPromises) : null, + () => { + if (field.value !== null && !issues.length) { + return completeValue( + mapValue(mapped[fieldName], field, addIssues([...path, fieldName]), ...args), + (newVal) => { + mapped[fieldName] = newVal; + }, + ); + } + }, + ); + + if (isThenable(promise)) { + promises.push(promise); + } + }); + + return completeValue(promises.length ? Promise.all(promises) : null, () => { + return issues.length + ? { + issues, + } + : { + value: mapped, + issues: undefined, + }; + }); + }; +} +function reduceMaybeAsync( + items: T[], + initialValue: R, + fn: (value: R, item: T, i: number) => MaybePromise, +) { + function next(value: R, i: number): MaybePromise { + if (i === items.length) { + return value; + } + + return completeValue(fn(value, items[i], i), (result) => { + return result === null ? null : next(result, i + 1); + }); + } + + return next(initialValue, 0); +} diff --git a/packages/plugin-validation/tests/example/schema/index.ts b/packages/plugin-validation/tests/example/schema/index.ts index e21657e59..976efe7e2 100644 --- a/packages/plugin-validation/tests/example/schema/index.ts +++ b/packages/plugin-validation/tests/example/schema/index.ts @@ -369,25 +369,42 @@ import builder from '../builder'; // }), // ); -const someSchema = zod.unknown(); +const stringSchema = zod.string().min(2).startsWith('A'); +const stringSchema2 = zod + .string() + .startsWith('a') + .transform(async (val) => val.toUpperCase()); + +const inputSchema = zod.object({ + someField: zod.string().min(2).startsWith('A'), +}); + +const argSchema = zod.object({ + someArg: zod.string().min(2).startsWith('A'), + someInput: inputSchema, + other: zod.string().min(2).startsWith('A'), +}); // Schema for inputs const SomeInput = builder.inputType('SomeInput', { - validate: someSchema, + validate: inputSchema, fields: (t) => ({ - someField: t.string({ validate: someSchema }).validate(someSchema), + someField: t.string({ validate: stringSchema }).validate(stringSchema2), }), }); builder.queryFields((t) => ({ someQuery: t.string({ args: { - someArg: t.arg.string({ validate: someSchema }).validate(someSchema), - someInput: t.arg({ type: SomeInput }).validate(someSchema), + someArg: t.arg.string({ validate: stringSchema }).validate(stringSchema2), + someInput: t.arg({ type: SomeInput }).validate(inputSchema), + other: t.arg.string(), }, - validate: someSchema, - resolve: () => 'result', + validate: argSchema, + resolve: (_, args) => JSON.stringify(args, null, 2), }), })); +builder.queryType({}); + export default builder.toSchema();