Skip to content

Commit

Permalink
Support new discriminator keys (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
samchungy authored Apr 20, 2024
1 parent 4219200 commit 3cf4b1b
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ For example in `z.string().nullable()` will be rendered differently
- `string` `type` mapping by default
- ZodDefault
- ZodDiscriminatedUnion
- `discriminator` mapping when all schemas in the union contain a `ref`.
- `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral` string value. Only `ZodLiteral` values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported.
- ZodEffects
- `transform` support for request schemas. See [Zod Effects](#zod-effects) for how to enable response schema support
- `pre-process` support. We assume that the input type is the same as the output type. Otherwise pipe and transform can be used instead.
Expand Down
181 changes: 181 additions & 0 deletions src/create/schema/parsers/discriminatedUnion.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,185 @@ describe('createDiscriminatedUnionSchema', () => {

expect(result).toEqual(expected);
});

it('handles a discriminated union with an optional type', () => {
const expected: Schema = {
type: 'schema',
schema: {
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.literal('a').optional(),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with a nullable type', () => {
const expected: Schema = {
type: 'schema',
schema: {
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.literal('a').nullable(),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with a branded type', () => {
const expected: Schema = {
type: 'schema',
schema: {
discriminator: {
mapping: {
a: '#/components/schemas/a',
b: '#/components/schemas/b',
},
propertyName: 'type',
},
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.literal('a').brand(),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with a readonly type', () => {
const expected: Schema = {
type: 'schema',
schema: {
discriminator: {
mapping: {
a: '#/components/schemas/a',
b: '#/components/schemas/b',
},
propertyName: 'type',
},
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.literal('a').readonly(),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with a catch type', () => {
const expected: Schema = {
type: 'schema',
schema: {
discriminator: {
mapping: {
a: '#/components/schemas/a',
b: '#/components/schemas/b',
},
propertyName: 'type',
},
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.literal('a').catch('a'),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});
});
38 changes: 29 additions & 9 deletions src/create/schema/parsers/discriminatedUnion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ import type {
AnyZodObject,
ZodDiscriminatedUnion,
ZodDiscriminatedUnionOption,
ZodLiteralDef,
ZodRawShape,
ZodType,
ZodTypeAny,
} from 'zod';

import type { oas31 } from '../../../openapi3-ts/dist';
Expand Down Expand Up @@ -32,7 +33,6 @@ export const createDiscriminatedUnionSchema = <
schemaObjects,
options,
zodDiscriminatedUnion.discriminator,
state,
);
return {
type: 'schema',
Expand All @@ -44,11 +44,35 @@ export const createDiscriminatedUnionSchema = <
};
};

const unwrapLiteral = (
zodType: ZodType | ZodTypeAny | undefined,
): string | undefined => {
if (isZodType(zodType, 'ZodLiteral')) {
if (typeof zodType._def.value !== 'string') {
return undefined;
}
return zodType._def.value;
}

if (isZodType(zodType, 'ZodBranded')) {
return unwrapLiteral(zodType._def.type);
}

if (isZodType(zodType, 'ZodReadonly')) {
return unwrapLiteral(zodType._def.innerType);
}

if (isZodType(zodType, 'ZodCatch')) {
return unwrapLiteral(zodType._def.innerType);
}

return undefined;
};

export const mapDiscriminator = (
schemas: Array<oas31.SchemaObject | oas31.ReferenceObject>,
zodObjects: AnyZodObject[],
discriminator: unknown,
state: SchemaState,
): oas31.SchemaObject['discriminator'] => {
if (typeof discriminator !== 'string') {
return undefined;
Expand All @@ -71,14 +95,10 @@ export const mapDiscriminator = (
continue;
}

const literalValue = (value?._def as ZodLiteralDef<unknown>).value;
const literalValue = unwrapLiteral(value);

if (typeof literalValue !== 'string') {
throw new Error(
`Discriminator ${discriminator} could not be found in on index ${index} of a discriminated union at ${state.path.join(
' > ',
)}`,
);
return undefined;
}

mapping[literalValue] = componentSchemaRef;
Expand Down

0 comments on commit 3cf4b1b

Please sign in to comment.