Skip to content

Commit

Permalink
Make Schema a discriminated union
Browse files Browse the repository at this point in the history
This leverages the type system to better describe the API's requirements for schemas. For example, rather than saying that any schema might have an optional `items` property, we're able to express that `items` is required on array schemas and forbidden on all others.

More info on discriminated unions: https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions
  • Loading branch information
rictic committed Nov 13, 2024
1 parent 6ec2c27 commit 14a8662
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 31 deletions.
5 changes: 5 additions & 0 deletions .changeset/young-rivers-shout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

The schema types are now more specific, using a [discriminated union](https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions) based on the 'type' field to more accurately define which fields are allowed.
22 changes: 11 additions & 11 deletions src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
FunctionCallingMode,
HarmBlockThreshold,
HarmCategory,
ObjectSchema,
SchemaType,
} from "../../types";
import { getMockResponse } from "../../test-utils/mock-response";
Expand Down Expand Up @@ -60,7 +61,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -93,7 +93,8 @@ describe("GenerativeModel", () => {
SchemaType.OBJECT,
);
expect(
genModel.generationConfig?.responseSchema.properties.testField.type,
(genModel.generationConfig?.responseSchema as ObjectSchema).properties
.testField.type,
).to.equal(SchemaType.STRING);
expect(genModel.generationConfig?.presencePenalty).to.equal(0.6);
expect(genModel.generationConfig?.frequencyPenalty).to.equal(0.5);
Expand Down Expand Up @@ -172,7 +173,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -206,7 +206,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -332,16 +331,17 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
},
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
Expand Down Expand Up @@ -372,7 +372,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand All @@ -381,8 +380,10 @@ describe("GenerativeModel", () => {
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -403,7 +404,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down
112 changes: 92 additions & 20 deletions types/function-calling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,100 @@ export enum SchemaType {
* More fields may be added in the future as needed.
* @public
*/
export interface Schema {
/**
* Optional. The type of the property. {@link
* SchemaType}.
*/
type?: SchemaType;
/** Optional. The format of the property. */
format?: string;
/** Optional. The description of the property. */
export type Schema =
| StringSchema
| NumberSchema
| IntegerSchema
| BooleanSchema
| ArraySchema
| ObjectSchema;

/**
* Fields common to all Schema types.
*/
interface BaseSchema {
/** Optional. Description of the value. */
description?: string;
/** Optional. Whether the property is nullable. */
/** If true, the value can be null. */
nullable?: boolean;
/** Optional. The items of the property. */
items?: Schema;
/** Optional. The enum of the property. */

// The field 'example' is accepted, but in testing, it seems like it accepts
// any value of any type, even when that doesn't match the type that the
// schema describes, and it doesn't appear to affect the model's output.
}

/**
* Describes a JSON-encodable floating point number.
*/
export interface NumberSchema extends BaseSchema {
type: typeof SchemaType.NUMBER;
/** Optional. The format of the number. */
format?: "float" | "double";

// Note that the API accepts `minimum` and `maximum` fields here, as numbers,
// but when tested they had no effect.
}

/**
* Describes a JSON-encodable integer.
*/
export interface IntegerSchema extends BaseSchema {
type: typeof SchemaType.INTEGER;
/** Optional. The format of the number. */
format?: "int32" | "int64"; // server rejects int32 or int64

// Note that the API accepts minimum and maximum fields here, as numbers,
// but when tested they had no effect.
}

/**
* Describes a string.
*/
export interface StringSchema extends BaseSchema {
type: typeof SchemaType.STRING;
/** If present, limits the result to one of the given values. */
enum?: string[];
/** Optional. Map of {@link Schema}. */
properties?: { [k: string]: Schema };
/** Optional. Array of required property. */
// Note that the API accepts the `pattern`, `minLength`, and `maxLength`
// fields, but they may only be advisory.
// The `format` field is not (at time of writing) supported on strings.
}

/**
* Describes a boolean, either 'true' or 'false'.
*/
export interface BooleanSchema extends BaseSchema {
type: typeof SchemaType.BOOLEAN;
}

/**
* Describes an array, an ordered list of values.
*/
export interface ArraySchema extends BaseSchema {
type: typeof SchemaType.ARRAY;
/** A schema describing the entries in the array. */
items: Schema;

/** The minimum number of items in the array. */
minItems?: number;
/** The maximum number of items in the array. */
maxItems?: number;
}

/**
* Describes a JSON object, a mapping of specific keys to values.
*/
export interface ObjectSchema extends BaseSchema {
type: typeof SchemaType.OBJECT;
/** Describes the properties of the JSON object. Must not be empty. */
properties: { [k: string]: Schema };
/**
* A list of keys declared in the properties object.
* Required properties will always be present in the generated object.
*/
required?: string[];
/** Optional. The example of the property. */
example?: unknown;

// Note that the API accepts the `minProperties`, and `maxProperties` fields,
// but they may only be advisory.
}

/**
Expand All @@ -148,13 +220,13 @@ export interface FunctionDeclarationSchema {
* Schema for top-level function declaration
* @public
*/
export interface FunctionDeclarationSchemaProperty extends Schema {}
export type FunctionDeclarationSchemaProperty = Schema;

/**
* Schema passed to `GenerationConfig.responseSchema`
* @public
*/
export interface ResponseSchema extends Schema {}
export type ResponseSchema = Schema;

/**
* Tool config. This config is shared for all tools provided in the request.
Expand Down

0 comments on commit 14a8662

Please sign in to comment.