diff --git a/README.md b/README.md index c5367c9..474615d 100644 --- a/README.md +++ b/README.md @@ -512,7 +512,7 @@ For example in `z.string().nullable()` will be rendered differently - `string` `type` mapping by default - ZodDefault - ZodDiscriminatedUnion - - `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral` string value. Only `ZodLiteral` values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported. + - `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral`, `ZodEnum` or `ZodNativeEnum` with string values. Only values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported. - ZodEffects - `transform` support for request schemas. See [Zod Effects](#zod-effects) for how to enable response schema support - `pre-process` support. We assume that the input type is the same as the output type. Otherwise pipe and transform can be used instead. diff --git a/src/create/schema/parsers/discriminatedUnion.test.ts b/src/create/schema/parsers/discriminatedUnion.test.ts index b13d9bd..7eab258 100644 --- a/src/create/schema/parsers/discriminatedUnion.test.ts +++ b/src/create/schema/parsers/discriminatedUnion.test.ts @@ -178,6 +178,89 @@ describe('createDiscriminatedUnionSchema', () => { expect(result).toEqual(expected); }); + it('creates a oneOf schema with discriminator mapping when schemas with string nativeEnums', () => { + const expected: Schema = { + type: 'schema', + schema: { + discriminator: { + mapping: { + a: '#/components/schemas/a', + c: '#/components/schemas/a', + b: '#/components/schemas/b', + }, + propertyName: 'type', + }, + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + enum letters { + a = 'a', + c = 'c', + } + + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.nativeEnum(letters), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + + it('creates a oneOf schema without discriminator mapping when schemas with mixed nativeEnums', () => { + const expected: Schema = { + type: 'schema', + schema: { + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + enum mixed { + a = 'a', + c = 'c', + d = 1, + } + + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.nativeEnum(mixed), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + it('handles a discriminated union with an optional type', () => { const expected: Schema = { type: 'schema', @@ -281,6 +364,46 @@ describe('createDiscriminatedUnionSchema', () => { expect(result).toEqual(expected); }); + it('handles a discriminated union with a branded enum type', () => { + const expected: Schema = { + type: 'schema', + schema: { + discriminator: { + mapping: { + a: '#/components/schemas/a', + c: '#/components/schemas/a', + b: '#/components/schemas/b', + }, + propertyName: 'type', + }, + oneOf: [ + { + $ref: '#/components/schemas/a', + }, + { + $ref: '#/components/schemas/b', + }, + ], + }, + }; + const schema = z.discriminatedUnion('type', [ + z + .object({ + type: z.enum(['a', 'c']).brand(), + }) + .openapi({ ref: 'a' }), + z + .object({ + type: z.literal('b'), + }) + .openapi({ ref: 'b' }), + ]); + + const result = createDiscriminatedUnionSchema(schema, createOutputState()); + + expect(result).toEqual(expected); + }); + it('handles a discriminated union with a readonly type', () => { const expected: Schema = { type: 'schema', diff --git a/src/create/schema/parsers/discriminatedUnion.ts b/src/create/schema/parsers/discriminatedUnion.ts index dee360f..6861dfc 100644 --- a/src/create/schema/parsers/discriminatedUnion.ts +++ b/src/create/schema/parsers/discriminatedUnion.ts @@ -15,6 +15,7 @@ import { createSchemaObject, } from '../../schema'; +import { createNativeEnumSchema } from './nativeEnum'; import { flattenEffects } from './transform'; export const createDiscriminatedUnionSchema = < @@ -33,6 +34,7 @@ export const createDiscriminatedUnionSchema = < schemaObjects, options, zodDiscriminatedUnion.discriminator, + state, ); return { type: 'schema', @@ -44,26 +46,38 @@ export const createDiscriminatedUnionSchema = < }; }; -const unwrapLiteral = ( +const unwrapLiterals = ( zodType: ZodType | ZodTypeAny | undefined, -): string | undefined => { + state: SchemaState, +): string[] | undefined => { if (isZodType(zodType, 'ZodLiteral')) { if (typeof zodType._def.value !== 'string') { return undefined; } - return zodType._def.value; + return [zodType._def.value]; + } + + if (isZodType(zodType, 'ZodNativeEnum')) { + const schema = createNativeEnumSchema(zodType, state); + if (schema.type === 'schema' && schema.schema.type === 'string') { + return schema.schema.enum; + } + } + + if (isZodType(zodType, 'ZodEnum')) { + return zodType._def.values; } if (isZodType(zodType, 'ZodBranded')) { - return unwrapLiteral(zodType._def.type); + return unwrapLiterals(zodType._def.type, state); } if (isZodType(zodType, 'ZodReadonly')) { - return unwrapLiteral(zodType._def.innerType); + return unwrapLiterals(zodType._def.innerType, state); } if (isZodType(zodType, 'ZodCatch')) { - return unwrapLiteral(zodType._def.innerType); + return unwrapLiterals(zodType._def.innerType, state); } return undefined; @@ -73,6 +87,7 @@ export const mapDiscriminator = ( schemas: Array, zodObjects: AnyZodObject[], discriminator: unknown, + state: SchemaState, ): oas31.SchemaObject['discriminator'] => { if (typeof discriminator !== 'string') { return undefined; @@ -88,20 +103,15 @@ export const mapDiscriminator = ( const value = (zodObject.shape as ZodRawShape)[discriminator]; - if (isZodType(value, 'ZodEnum')) { - for (const enumValue of value._def.values as string[]) { - mapping[enumValue] = componentSchemaRef; - } - continue; - } + const literals = unwrapLiterals(value, state); - const literalValue = unwrapLiteral(value); - - if (typeof literalValue !== 'string') { + if (!literals) { return undefined; } - mapping[literalValue] = componentSchemaRef; + for (const enumValue of literals) { + mapping[enumValue] = componentSchemaRef; + } } return {