diff --git a/README.md b/README.md index 6552bad..c5367c9 100644 --- a/README.md +++ b/README.md @@ -512,7 +512,7 @@ For example in `z.string().nullable()` will be rendered differently - `string` `type` mapping by default - ZodDefault - ZodDiscriminatedUnion - - `discriminator` mapping when all schemas in the union contain a `ref`. + - `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral` string value. Only `ZodLiteral` values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported. - ZodEffects - `transform` support for request schemas. See [Zod Effects](#zod-effects) for how to enable response schema support - `pre-process` support. We assume that the input type is the same as the output type. Otherwise pipe and transform can be used instead. diff --git a/src/create/schema/parsers/discriminatedUnion.test.ts b/src/create/schema/parsers/discriminatedUnion.test.ts index 5201b31..b13d9bd 100644 --- a/src/create/schema/parsers/discriminatedUnion.test.ts +++ b/src/create/schema/parsers/discriminatedUnion.test.ts @@ -177,4 +177,185 @@ describe('createDiscriminatedUnionSchema', () => { expect(result).toEqual(expected); }); + + it('handles a discriminated union with an optional type', () => { + const expected: Schema = { + type: 'schema', + schema: { + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.literal('a').optional(), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + + it('handles a discriminated union with a nullable type', () => { + const expected: Schema = { + type: 'schema', + schema: { + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.literal('a').nullable(), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + + it('handles a discriminated union with a branded type', () => { + const expected: Schema = { + type: 'schema', + schema: { + discriminator: { + mapping: { + a: '#/components/schemas/a', + b: '#/components/schemas/b', + }, + propertyName: 'type', + }, + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.literal('a').brand(), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + + it('handles a discriminated union with a readonly type', () => { + const expected: Schema = { + type: 'schema', + schema: { + discriminator: { + mapping: { + a: '#/components/schemas/a', + b: '#/components/schemas/b', + }, + propertyName: 'type', + }, + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.literal('a').readonly(), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + + it('handles a discriminated union with a catch type', () => { + const expected: Schema = { + type: 'schema', + schema: { + discriminator: { + mapping: { + a: '#/components/schemas/a', + b: '#/components/schemas/b', + }, + propertyName: 'type', + }, + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.literal('a').catch('a'), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); }); diff --git a/src/create/schema/parsers/discriminatedUnion.ts b/src/create/schema/parsers/discriminatedUnion.ts index 6a3f969..dee360f 100644 --- a/src/create/schema/parsers/discriminatedUnion.ts +++ b/src/create/schema/parsers/discriminatedUnion.ts @@ -2,8 +2,9 @@ import type { AnyZodObject, ZodDiscriminatedUnion, ZodDiscriminatedUnionOption, - ZodLiteralDef, ZodRawShape, + ZodType, + ZodTypeAny, } from 'zod'; import type { oas31 } from '../../../openapi3-ts/dist'; @@ -32,7 +33,6 @@ export const createDiscriminatedUnionSchema = < schemaObjects, options, zodDiscriminatedUnion.discriminator, - state, ); return { type: 'schema', @@ -44,11 +44,35 @@ export const createDiscriminatedUnionSchema = < }; }; +const unwrapLiteral = ( + zodType: ZodType | ZodTypeAny | undefined, +): string | undefined => { + if (isZodType(zodType, 'ZodLiteral')) { + if (typeof zodType._def.value !== 'string') { + return undefined; + } + return zodType._def.value; + } + + if (isZodType(zodType, 'ZodBranded')) { + return unwrapLiteral(zodType._def.type); + } + + if (isZodType(zodType, 'ZodReadonly')) { + return unwrapLiteral(zodType._def.innerType); + } + + if (isZodType(zodType, 'ZodCatch')) { + return unwrapLiteral(zodType._def.innerType); + } + + return undefined; +}; + export const mapDiscriminator = ( schemas: Array, zodObjects: AnyZodObject[], discriminator: unknown, - state: SchemaState, ): oas31.SchemaObject['discriminator'] => { if (typeof discriminator !== 'string') { return undefined; @@ -71,14 +95,10 @@ export const mapDiscriminator = ( continue; } - const literalValue = (value?._def as ZodLiteralDef).value; + const literalValue = unwrapLiteral(value); if (typeof literalValue !== 'string') { - throw new Error( - `Discriminator ${discriminator} could not be found in on index ${index} of a discriminated union at ${state.path.join( - ' > ', - )}`, - ); + return undefined; } mapping[literalValue] = componentSchemaRef;