Skip to content

Commit

Permalink
Make Schema a discriminated union
Browse files Browse the repository at this point in the history
This leverages the type system to better describe the API's requirements for schemas. For example, rather than saying that any schema might have an optional `items` property, we're able to express that `items` is required on array schemas and forbidden on all others.

More info on discriminated unions: https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions
  • Loading branch information
rictic committed Nov 13, 2024
1 parent 6ec2c27 commit f7d75d5
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 68 deletions.
5 changes: 5 additions & 0 deletions .changeset/young-rivers-shout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

The schema types are now more specific, using a [discriminated union](https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions) based on the 'type' field to more accurately define which fields are allowed.
72 changes: 55 additions & 17 deletions common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@
```ts

// @public
export interface ArraySchema extends BaseSchema {
items: Schema;
maxItems?: number;
minItems?: number;
// (undocumented)
type: typeof SchemaType.ARRAY;
}

// @public
export interface BaseSchema {
description?: string;
nullable?: boolean;
}

// @public
export interface BooleanSchema extends BaseSchema {
// (undocumented)
type: typeof SchemaType.BOOLEAN;
}

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down Expand Up @@ -286,8 +307,7 @@ export interface FunctionDeclarationSchema {
}

// @public
export interface FunctionDeclarationSchemaProperty extends Schema {
}
export type FunctionDeclarationSchemaProperty = Schema;

// @public
export interface FunctionDeclarationsTool {
Expand Down Expand Up @@ -368,6 +388,13 @@ export interface InlineDataPart {
text?: never;
}

// @public
export interface IntegerSchema extends BaseSchema {
format?: "int32" | "int64";
// (undocumented)
type: typeof SchemaType.INTEGER;
}

// @public (undocumented)
export interface ListCacheResponse {
// (undocumented)
Expand All @@ -392,6 +419,23 @@ export interface ListParams {
pageToken?: string;
}

// @public
export interface NumberSchema extends BaseSchema {
format?: "float" | "double";
// (undocumented)
type: typeof SchemaType.NUMBER;
}

// @public
export interface ObjectSchema extends BaseSchema {
properties: {
[k: string]: Schema;
};
required?: string[];
// (undocumented)
type: typeof SchemaType.OBJECT;
}

// @public
export enum Outcome {
OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded",
Expand All @@ -413,8 +457,7 @@ export interface RequestOptions {
}

// @public
export interface ResponseSchema extends Schema {
}
export type ResponseSchema = Schema;

// @public
export interface RpcStatus {
Expand All @@ -424,19 +467,7 @@ export interface RpcStatus {
}

// @public
export interface Schema {
description?: string;
enum?: string[];
example?: unknown;
format?: string;
items?: Schema;
nullable?: boolean;
properties?: {
[k: string]: Schema;
};
required?: string[];
type?: SchemaType;
}
export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema;

// @public
export enum SchemaType {
Expand All @@ -453,6 +484,13 @@ export interface SingleRequestOptions extends RequestOptions {
signal?: AbortSignal;
}

// @public
export interface StringSchema extends BaseSchema {
enum?: string[];
// (undocumented)
type: typeof SchemaType.STRING;
}

// @public
export interface TextPart {
// (undocumented)
Expand Down
72 changes: 55 additions & 17 deletions common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
```ts

// @public
export interface ArraySchema extends BaseSchema {
items: Schema;
maxItems?: number;
minItems?: number;
// (undocumented)
type: typeof SchemaType.ARRAY;
}

// @public
export interface BaseParams {
// (undocumented)
Expand All @@ -12,6 +21,12 @@ export interface BaseParams {
safetySettings?: SafetySetting[];
}

// @public
export interface BaseSchema {
description?: string;
nullable?: boolean;
}

// @public
export interface BatchEmbedContentsRequest {
// (undocumented)
Expand All @@ -34,6 +49,12 @@ export enum BlockReason {
SAFETY = "SAFETY"
}

// @public
export interface BooleanSchema extends BaseSchema {
// (undocumented)
type: typeof SchemaType.BOOLEAN;
}

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down Expand Up @@ -355,8 +376,7 @@ export interface FunctionDeclarationSchema {
}

// @public
export interface FunctionDeclarationSchemaProperty extends Schema {
}
export type FunctionDeclarationSchemaProperty = Schema;

// @public
export interface FunctionDeclarationsTool {
Expand Down Expand Up @@ -645,6 +665,13 @@ export interface InlineDataPart {
text?: never;
}

// @public
export interface IntegerSchema extends BaseSchema {
format?: "int32" | "int64";
// (undocumented)
type: typeof SchemaType.INTEGER;
}

// @public
export interface LogprobsCandidate {
logProbability: number;
Expand Down Expand Up @@ -672,6 +699,23 @@ export interface ModelParams extends BaseParams {
tools?: Tool[];
}

// @public
export interface NumberSchema extends BaseSchema {
format?: "float" | "double";
// (undocumented)
type: typeof SchemaType.NUMBER;
}

// @public
export interface ObjectSchema extends BaseSchema {
properties: {
[k: string]: Schema;
};
required?: string[];
// (undocumented)
type: typeof SchemaType.OBJECT;
}

// @public
export enum Outcome {
OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded",
Expand Down Expand Up @@ -706,8 +750,7 @@ export interface RequestOptions {
}

// @public
export interface ResponseSchema extends Schema {
}
export type ResponseSchema = Schema;

// @public
export interface RetrievalMetadata {
Expand All @@ -731,19 +774,7 @@ export interface SafetySetting {
}

// @public
export interface Schema {
description?: string;
enum?: string[];
example?: unknown;
format?: string;
items?: Schema;
nullable?: boolean;
properties?: {
[k: string]: Schema;
};
required?: string[];
type?: SchemaType;
}
export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema;

// @public
export enum SchemaType {
Expand Down Expand Up @@ -779,6 +810,13 @@ export interface StartChatParams extends BaseParams {
tools?: Tool[];
}

// @public
export interface StringSchema extends BaseSchema {
enum?: string[];
// (undocumented)
type: typeof SchemaType.STRING;
}

// @public
export enum TaskType {
// (undocumented)
Expand Down
9 changes: 8 additions & 1 deletion rollup.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ import replace from "rollup-plugin-replace";
import typescriptPlugin from "rollup-plugin-typescript2";
import typescript from "typescript";
import json from "@rollup/plugin-json";
import pkg from "./package.json" assert { type: "json" };
import * as fs from "node:fs";
import * as path from "node:path";
import { fileURLToPath } from "node:url";

const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

const pkg = JSON.parse(fs.readFileSync(path.join(__dirname, "package.json")));

const es2017BuildPlugins = [
typescriptPlugin({
Expand Down
2 changes: 1 addition & 1 deletion samples/web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This sample app demonstrates how to use state-of-the-art
generative AI models (like Gemini) to build AI-powered features and applications.

To try out this sample app, you'll need a modern web browser and a local http server.
To try out this sample app, run `npm i` and then `API_KEY={your API key} npm start`.

## Requirements

Expand Down
2 changes: 1 addition & 1 deletion samples/web/utils/shared.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import { marked } from "https://esm.run/marked";
export async function getGenerativeModel(params) {
// Fetch API key from server
// If you need a new API key, get it from https://makersuite.google.com/app/apikey
const API_KEY = await (await fetch("API_KEY")).text();
const API_KEY = await (await fetch("/API_KEY")).text();

const genAI = new GoogleGenerativeAI(API_KEY);

Expand Down
22 changes: 11 additions & 11 deletions src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
FunctionCallingMode,
HarmBlockThreshold,
HarmCategory,
ObjectSchema,
SchemaType,
} from "../../types";
import { getMockResponse } from "../../test-utils/mock-response";
Expand Down Expand Up @@ -60,7 +61,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -93,7 +93,8 @@ describe("GenerativeModel", () => {
SchemaType.OBJECT,
);
expect(
genModel.generationConfig?.responseSchema.properties.testField.type,
(genModel.generationConfig?.responseSchema as ObjectSchema).properties
.testField.type,
).to.equal(SchemaType.STRING);
expect(genModel.generationConfig?.presencePenalty).to.equal(0.6);
expect(genModel.generationConfig?.frequencyPenalty).to.equal(0.5);
Expand Down Expand Up @@ -172,7 +173,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -206,7 +206,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -332,16 +331,17 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
},
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
Expand Down Expand Up @@ -372,7 +372,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand All @@ -381,8 +380,10 @@ describe("GenerativeModel", () => {
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -403,7 +404,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down
Loading

0 comments on commit f7d75d5

Please sign in to comment.