From f7d75d5b8e2342b6ddaf347ee485b444cf2e8919 Mon Sep 17 00:00:00 2001 From: Peter Burns Date: Tue, 12 Nov 2024 20:08:11 -0800 Subject: [PATCH] Make Schema a discriminated union This leverages the type system to better describe the API's requirements for schemas. For example, rather than saying that any schema might have an optional `items` property, we're able to express that `items` is required on array schemas and forbidden on all others. More info on discriminated unions: https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions --- .changeset/young-rivers-shout.md | 5 + common/api-review/generative-ai-server.api.md | 72 ++++++++--- common/api-review/generative-ai.api.md | 72 ++++++++--- rollup.config.mjs | 9 +- samples/web/README.md | 2 +- samples/web/utils/shared.js | 2 +- src/models/generative-model.test.ts | 22 ++-- types/function-calling.ts | 112 ++++++++++++++---- 8 files changed, 228 insertions(+), 68 deletions(-) create mode 100644 .changeset/young-rivers-shout.md diff --git a/.changeset/young-rivers-shout.md b/.changeset/young-rivers-shout.md new file mode 100644 index 00000000..ceeb85b7 --- /dev/null +++ b/.changeset/young-rivers-shout.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": minor +--- + +The schema types are now more specific, using a [discriminated union](https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions) based on the 'type' field to more accurately define which fields are allowed. 
diff --git a/common/api-review/generative-ai-server.api.md b/common/api-review/generative-ai-server.api.md index 4265a3ab..bebd633d 100644 --- a/common/api-review/generative-ai-server.api.md +++ b/common/api-review/generative-ai-server.api.md @@ -4,6 +4,27 @@ ```ts +// @public +export interface ArraySchema extends BaseSchema { + items: Schema; + maxItems?: number; + minItems?: number; + // (undocumented) + type: typeof SchemaType.ARRAY; +} + +// @public +export interface BaseSchema { + description?: string; + nullable?: boolean; +} + +// @public +export interface BooleanSchema extends BaseSchema { + // (undocumented) + type: typeof SchemaType.BOOLEAN; +} + // @public export interface CachedContent extends CachedContentBase { createTime?: string; @@ -286,8 +307,7 @@ export interface FunctionDeclarationSchema { } // @public -export interface FunctionDeclarationSchemaProperty extends Schema { -} +export type FunctionDeclarationSchemaProperty = Schema; // @public export interface FunctionDeclarationsTool { @@ -368,6 +388,13 @@ export interface InlineDataPart { text?: never; } +// @public +export interface IntegerSchema extends BaseSchema { + format?: "int32" | "int64"; + // (undocumented) + type: typeof SchemaType.INTEGER; +} + // @public (undocumented) export interface ListCacheResponse { // (undocumented) @@ -392,6 +419,23 @@ export interface ListParams { pageToken?: string; } +// @public +export interface NumberSchema extends BaseSchema { + format?: "float" | "double"; + // (undocumented) + type: typeof SchemaType.NUMBER; +} + +// @public +export interface ObjectSchema extends BaseSchema { + properties: { + [k: string]: Schema; + }; + required?: string[]; + // (undocumented) + type: typeof SchemaType.OBJECT; +} + // @public export enum Outcome { OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded", @@ -413,8 +457,7 @@ export interface RequestOptions { } // @public -export interface ResponseSchema extends Schema { -} +export type ResponseSchema = Schema; // 
@public export interface RpcStatus { @@ -424,19 +467,7 @@ export interface RpcStatus { } // @public -export interface Schema { - description?: string; - enum?: string[]; - example?: unknown; - format?: string; - items?: Schema; - nullable?: boolean; - properties?: { - [k: string]: Schema; - }; - required?: string[]; - type?: SchemaType; -} +export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema; // @public export enum SchemaType { @@ -453,6 +484,13 @@ export interface SingleRequestOptions extends RequestOptions { signal?: AbortSignal; } +// @public +export interface StringSchema extends BaseSchema { + enum?: string[]; + // (undocumented) + type: typeof SchemaType.STRING; +} + // @public export interface TextPart { // (undocumented) diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md index c06c2224..1f225991 100644 --- a/common/api-review/generative-ai.api.md +++ b/common/api-review/generative-ai.api.md @@ -4,6 +4,15 @@ ```ts +// @public +export interface ArraySchema extends BaseSchema { + items: Schema; + maxItems?: number; + minItems?: number; + // (undocumented) + type: typeof SchemaType.ARRAY; +} + // @public export interface BaseParams { // (undocumented) @@ -12,6 +21,12 @@ export interface BaseParams { safetySettings?: SafetySetting[]; } +// @public +export interface BaseSchema { + description?: string; + nullable?: boolean; +} + // @public export interface BatchEmbedContentsRequest { // (undocumented) @@ -34,6 +49,12 @@ export enum BlockReason { SAFETY = "SAFETY" } +// @public +export interface BooleanSchema extends BaseSchema { + // (undocumented) + type: typeof SchemaType.BOOLEAN; +} + // @public export interface CachedContent extends CachedContentBase { createTime?: string; @@ -355,8 +376,7 @@ export interface FunctionDeclarationSchema { } // @public -export interface FunctionDeclarationSchemaProperty extends Schema { -} +export type 
FunctionDeclarationSchemaProperty = Schema; // @public export interface FunctionDeclarationsTool { @@ -645,6 +665,13 @@ export interface InlineDataPart { text?: never; } +// @public +export interface IntegerSchema extends BaseSchema { + format?: "int32" | "int64"; + // (undocumented) + type: typeof SchemaType.INTEGER; +} + // @public export interface LogprobsCandidate { logProbability: number; @@ -672,6 +699,23 @@ export interface ModelParams extends BaseParams { tools?: Tool[]; } +// @public +export interface NumberSchema extends BaseSchema { + format?: "float" | "double"; + // (undocumented) + type: typeof SchemaType.NUMBER; +} + +// @public +export interface ObjectSchema extends BaseSchema { + properties: { + [k: string]: Schema; + }; + required?: string[]; + // (undocumented) + type: typeof SchemaType.OBJECT; +} + // @public export enum Outcome { OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded", @@ -706,8 +750,7 @@ export interface RequestOptions { } // @public -export interface ResponseSchema extends Schema { -} +export type ResponseSchema = Schema; // @public export interface RetrievalMetadata { @@ -731,19 +774,7 @@ export interface SafetySetting { } // @public -export interface Schema { - description?: string; - enum?: string[]; - example?: unknown; - format?: string; - items?: Schema; - nullable?: boolean; - properties?: { - [k: string]: Schema; - }; - required?: string[]; - type?: SchemaType; -} +export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema; // @public export enum SchemaType { @@ -779,6 +810,13 @@ export interface StartChatParams extends BaseParams { tools?: Tool[]; } +// @public +export interface StringSchema extends BaseSchema { + enum?: string[]; + // (undocumented) + type: typeof SchemaType.STRING; +} + // @public export enum TaskType { // (undocumented) diff --git a/rollup.config.mjs b/rollup.config.mjs index 13e1efb9..b543aed5 100644 --- a/rollup.config.mjs +++ 
b/rollup.config.mjs @@ -19,7 +19,14 @@ import replace from "rollup-plugin-replace"; import typescriptPlugin from "rollup-plugin-typescript2"; import typescript from "typescript"; import json from "@rollup/plugin-json"; -import pkg from "./package.json" assert { type: "json" }; +import * as fs from "node:fs"; +import * as path from "node:path"; +import { fileURLToPath } from "node:url"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = path.dirname(__filename); + +const pkg = JSON.parse(fs.readFileSync(path.join(__dirname, "package.json"))); const es2017BuildPlugins = [ typescriptPlugin({ diff --git a/samples/web/README.md b/samples/web/README.md index 9149689a..2861629a 100644 --- a/samples/web/README.md +++ b/samples/web/README.md @@ -3,7 +3,7 @@ This sample app demonstrates how to use state-of-the-art generative AI models (like Gemini) to build AI-powered features and applications. -To try out this sample app, you'll need a modern web browser and a local http server. +To try out this sample app, run `npm i` and then `API_KEY={your API key} npm start`. 
## Requirements diff --git a/samples/web/utils/shared.js b/samples/web/utils/shared.js index d420a876..7469baa6 100644 --- a/samples/web/utils/shared.js +++ b/samples/web/utils/shared.js @@ -27,7 +27,7 @@ import { marked } from "https://esm.run/marked"; export async function getGenerativeModel(params) { // Fetch API key from server // If you need a new API key, get it from https://makersuite.google.com/app/apikey - const API_KEY = await (await fetch("API_KEY")).text(); + const API_KEY = await (await fetch("/API_KEY")).text(); const genAI = new GoogleGenerativeAI(API_KEY); diff --git a/src/models/generative-model.test.ts b/src/models/generative-model.test.ts index bfaabcd7..e6c6fd85 100644 --- a/src/models/generative-model.test.ts +++ b/src/models/generative-model.test.ts @@ -22,6 +22,7 @@ import { FunctionCallingMode, HarmBlockThreshold, HarmCategory, + ObjectSchema, SchemaType, } from "../../types"; import { getMockResponse } from "../../test-utils/mock-response"; @@ -60,7 +61,6 @@ describe("GenerativeModel", () => { properties: { testField: { type: SchemaType.STRING, - properties: {}, }, }, }, @@ -93,7 +93,8 @@ describe("GenerativeModel", () => { SchemaType.OBJECT, ); expect( - genModel.generationConfig?.responseSchema.properties.testField.type, + (genModel.generationConfig?.responseSchema as ObjectSchema).properties + .testField.type, ).to.equal(SchemaType.STRING); expect(genModel.generationConfig?.presencePenalty).to.equal(0.6); expect(genModel.generationConfig?.frequencyPenalty).to.equal(0.5); @@ -172,7 +173,6 @@ describe("GenerativeModel", () => { properties: { testField: { type: SchemaType.STRING, - properties: {}, }, }, }, @@ -206,7 +206,6 @@ describe("GenerativeModel", () => { properties: { newTestField: { type: SchemaType.STRING, - properties: {}, }, }, }, @@ -332,7 +331,6 @@ describe("GenerativeModel", () => { properties: { testField: { type: SchemaType.STRING, - properties: {}, }, }, }, @@ -340,8 +338,10 @@ describe("GenerativeModel", () => { 
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, }); expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly"); - expect(genModel.generationConfig.responseSchema.properties.testField).to - .exist; + expect( + (genModel.generationConfig.responseSchema as ObjectSchema).properties + .testField, + ).to.exist; const mockResponse = getMockResponse( "unary-success-basic-reply-short.json", ); @@ -372,7 +372,6 @@ describe("GenerativeModel", () => { properties: { testField: { type: SchemaType.STRING, - properties: {}, }, }, }, @@ -381,8 +380,10 @@ describe("GenerativeModel", () => { toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, }); - expect(genModel.generationConfig.responseSchema.properties.testField).to - .exist; + expect( + (genModel.generationConfig.responseSchema as ObjectSchema).properties + .testField, + ).to.exist; expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal( FunctionCallingMode.NONE, @@ -403,7 +404,6 @@ describe("GenerativeModel", () => { properties: { newTestField: { type: SchemaType.STRING, - properties: {}, }, }, }, diff --git a/types/function-calling.ts b/types/function-calling.ts index 02bdb5fc..1d0da57d 100644 --- a/types/function-calling.ts +++ b/types/function-calling.ts @@ -105,28 +105,100 @@ export enum SchemaType { * More fields may be added in the future as needed. * @public */ -export interface Schema { - /** - * Optional. The type of the property. {@link - * SchemaType}. - */ - type?: SchemaType; - /** Optional. The format of the property. */ - format?: string; - /** Optional. The description of the property. */ +export type Schema = + | StringSchema + | NumberSchema + | IntegerSchema + | BooleanSchema + | ArraySchema + | ObjectSchema; + +/** + * Fields common to all Schema types. + */ +export interface BaseSchema { + /** Optional. 
Description of the value. */ description?: string; - /** Optional. Whether the property is nullable. */ + /** If true, the value can be null. */ nullable?: boolean; - /** Optional. The items of the property. */ - items?: Schema; - /** Optional. The enum of the property. */ + + // The field 'example' is accepted, but in testing, it seems like it accepts + // any value of any type, even when that doesn't match the type that the + // schema describes, and it doesn't appear to affect the model's output. +} + +/** + * Describes a JSON-encodable floating point number. + */ +export interface NumberSchema extends BaseSchema { + type: typeof SchemaType.NUMBER; + /** Optional. The format of the number. */ + format?: "float" | "double"; + + // Note that the API accepts `minimum` and `maximum` fields here, as numbers, + // but when tested they had no effect. +} + +/** + * Describes a JSON-encodable integer. + */ +export interface IntegerSchema extends BaseSchema { + type: typeof SchemaType.INTEGER; + /** Optional. The format of the number. */ + format?: "int32" | "int64"; // server rejects int32 or int64 + + // Note that the API accepts minimum and maximum fields here, as numbers, + // but when tested they had no effect. +} + +/** + * Describes a string. + */ +export interface StringSchema extends BaseSchema { + type: typeof SchemaType.STRING; + /** If present, limits the result to one of the given values. */ enum?: string[]; - /** Optional. Map of {@link Schema}. */ - properties?: { [k: string]: Schema }; - /** Optional. Array of required property. */ + // Note that the API accepts the `pattern`, `minLength`, and `maxLength` + // fields, but they may only be advisory. + // The `format` field is not (at time of writing) supported on strings. +} + +/** + * Describes a boolean, either 'true' or 'false'. + */ +export interface BooleanSchema extends BaseSchema { + type: typeof SchemaType.BOOLEAN; +} + +/** + * Describes an array, an ordered list of values. 
+ */ +export interface ArraySchema extends BaseSchema { + type: typeof SchemaType.ARRAY; + /** A schema describing the entries in the array. */ + items: Schema; + + /** The minimum number of items in the array. */ + minItems?: number; + /** The maximum number of items in the array. */ + maxItems?: number; +} + +/** + * Describes a JSON object, a mapping of specific keys to values. + */ +export interface ObjectSchema extends BaseSchema { + type: typeof SchemaType.OBJECT; + /** Describes the properties of the JSON object. Must not be empty. */ + properties: { [k: string]: Schema }; + /** + * A list of keys declared in the properties object. + * Required properties will always be present in the generated object. + */ required?: string[]; - /** Optional. The example of the property. */ - example?: unknown; + + // Note that the API accepts the `minProperties` and `maxProperties` fields, + // but they may only be advisory. } /** @@ -148,13 +220,13 @@ export interface FunctionDeclarationSchema { * Schema for top-level function declaration * @public */ -export interface FunctionDeclarationSchemaProperty extends Schema {} +export type FunctionDeclarationSchemaProperty = Schema; /** * Schema passed to `GenerationConfig.responseSchema` * @public */ -export interface ResponseSchema extends Schema {} +export type ResponseSchema = Schema; /** * Tool config. This config is shared for all tools provided in the request.