Skip to content

Commit

Permalink
Input validation for runs (#32030)
Browse files Browse the repository at this point in the history
Co-authored-by: Zachary King <[email protected]>
  • Loading branch information
ZachhK and Zachary King authored Dec 4, 2024
1 parent fda5672 commit 419844f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
30 changes: 29 additions & 1 deletion sdk/ai/ai-projects/src/agents/inputValidations.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { VectorStoreDataSource } from "../generated/src/models.js";
import { ToolDefinition, VectorStoreDataSource } from "../generated/src/models.js";

enum Tools {
CodeInterpreter = "code_interpreter",
FileSearch = "file_search",
Function = "function",
BingGrounding = "bing_grounding",
MicrosoftFabric = "microsoft_fabric",
SharepointGrounding = "sharepoint_grounding",
AzureAISearch = "azure_ai_search",
}

export function validateVectorStoreDataType(data_sources: VectorStoreDataSource[]): void {
if (!data_sources.some(value => !["uri_asset", "id_asset"].includes(value.type))) {
throw new Error("Vector store data type must be one of 'uri_asset', 'id_asset'");
}
}

export function validateThreadId(threadId: string): void {
if (!threadId) {
throw new Error("Thread ID is required");
}
}

export function validateRunId(runId: string): void {
if (!runId) {
throw new Error("Run ID is required");
}
}

export function validateLimit(limit: number): void {
if (limit < 1 || limit > 100) {
throw new Error("Limit must be between 1 and 100");
Expand All @@ -21,6 +43,12 @@ export function validateOrder(order: string): void {
}
}

export function validateTools(value: Array<ToolDefinition>): void {
if (value.some(tool => !Object.values(Tools).includes(tool as unknown as Tools))) {
throw new Error("Tool type must be one of 'code_interpreter', 'file_search', 'function', 'bing_grounding', 'microsoft_fabric', 'sharepoint_grounding', 'azure_ai_search'");
}
}

export function validateMetadata(metadata: Record<string, string>): void {
if (Object.keys(metadata).length > 16) {
throw new Error("Only 16 key/value pairs are allowed");
Expand Down
27 changes: 27 additions & 0 deletions sdk/ai/ai-projects/src/agents/runSteps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { Client, createRestError } from "@azure-rest/core-client";
import { OpenAIPageableListOfRunStepOutput, RunStepOutput } from "../generated/src/outputModels.js";
import { GetRunStepParameters, ListRunStepsParameters } from "../generated/src/parameters.js";
import { validateLimit, validateOrder, validateRunId, validateThreadId } from "./inputValidations.js";

const expectedStatuses = ["200"];

Expand All @@ -15,6 +16,9 @@ export async function getRunStep(
stepId: string,
options?: GetRunStepParameters,
): Promise<RunStepOutput> {
validateThreadId(threadId);
validateRunId(runId);
validateStepId(stepId);
const result = await context
.path("/threads/{threadId}/runs/{runId}/steps/{stepId}", threadId, runId, stepId)
.get(options);
Expand All @@ -31,6 +35,7 @@ export async function listRunSteps(
runId: string,
options?: ListRunStepsParameters,
): Promise<OpenAIPageableListOfRunStepOutput> {
validateListRunsParameters(threadId, runId, options);
const result = await context
.path("/threads/{threadId}/runs/{runId}/steps", threadId, runId)
.get(options);
Expand All @@ -39,3 +44,25 @@ export async function listRunSteps(
}
return result.body;
}



function validateStepId(stepId: string): void {
if (!stepId) {
throw new Error("Step ID is required");
}
}

function validateListRunsParameters(thread_id: string, runId: string, options?: ListRunStepsParameters): void {
validateThreadId(thread_id);
validateRunId(runId);
if (options?.queryParameters?.limit && (options.queryParameters.limit < 1 || options.queryParameters.limit > 100)) {
throw new Error("Limit must be between 1 and 100");
}
if (options?.queryParameters?.limit) {
validateLimit(options.queryParameters.limit);
}
if (options?.queryParameters?.order) {
validateOrder(options.queryParameters.order);
}
}
81 changes: 81 additions & 0 deletions sdk/ai/ai-projects/src/agents/runs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { Client, createRestError } from "@azure-rest/core-client";
import { CancelRunParameters, CreateRunParameters, CreateThreadAndRunParameters, GetRunParameters, ListRunsParameters, SubmitToolOutputsToRunParameters, UpdateRunParameters } from "../generated/src/parameters.js";
import { OpenAIPageableListOfThreadRunOutput, ThreadRunOutput } from "../generated/src/outputModels.js";
import { validateLimit, validateMetadata, validateOrder, validateRunId, validateThreadId, validateTools } from "./inputValidations.js";

const expectedStatuses = ["200"];

Expand All @@ -13,6 +14,8 @@ export async function createRun(
threadId: string,
options: CreateRunParameters,
): Promise<ThreadRunOutput> {
validateThreadId(threadId);
validateCreateRunParameters(options);
options.body.stream = false;
const result = await context
.path("/threads/{threadId}/runs", threadId)
Expand All @@ -29,6 +32,7 @@ export async function listRuns(
threadId: string,
options?: ListRunsParameters,
): Promise<OpenAIPageableListOfThreadRunOutput> {
validateListRunsParameters(threadId, options);
const result = await context
.path("/threads/{threadId}/runs", threadId)
.get(options);
Expand All @@ -45,6 +49,8 @@ export async function getRun(
runId: string,
options?: GetRunParameters,
): Promise<ThreadRunOutput> {
validateThreadId(threadId);
validateRunId(runId);
const result = await context
.path("/threads/{threadId}/runs/{runId}", threadId, runId)
.get(options);
Expand All @@ -61,6 +67,7 @@ export async function updateRun(
runId: string,
options?: UpdateRunParameters,
): Promise<ThreadRunOutput> {
validateUpdateRunParameters(threadId, runId, options);
const result = await context
.path("/threads/{threadId}/runs/{runId}", threadId, runId)
.post(options);
Expand All @@ -77,6 +84,8 @@ export async function submitToolOutputsToRun(
runId: string,
options: SubmitToolOutputsToRunParameters,
): Promise<ThreadRunOutput> {
validateThreadId(threadId);
validateRunId(runId);
options.body.stream = false;
const result = await context
.path("/threads/{threadId}/runs/{runId}/submit_tool_outputs", threadId, runId)
Expand All @@ -94,6 +103,8 @@ export async function cancelRun(
runId: string,
options?: CancelRunParameters,
): Promise<ThreadRunOutput> {
validateThreadId(threadId);
validateRunId(runId);
const result = await context
.path("/threads/{threadId}/runs/{runId}/cancel", threadId, runId)
.post(options);
Expand All @@ -108,10 +119,80 @@ export async function createThreadAndRun(
context: Client,
options: CreateThreadAndRunParameters,
): Promise<ThreadRunOutput> {
validateCreateThreadAndRunParameters(options);
options.body.stream = false;
const result = await context.path("/threads/runs").post(options);
if (!expectedStatuses.includes(result.status)) {
throw createRestError(result);
}
return result.body;
}

function validateListRunsParameters(thread_id: string, options?: ListRunsParameters): void {
validateThreadId(thread_id);
if (options?.queryParameters?.limit && (options.queryParameters.limit < 1 || options.queryParameters.limit > 100)) {
throw new Error("Limit must be between 1 and 100");
}
if (options?.queryParameters?.limit) {
validateLimit(options.queryParameters.limit);
}
if (options?.queryParameters?.order) {
validateOrder(options.queryParameters.order);
}
}

function validateUpdateRunParameters(thread_id: string, run_id: string, options?: UpdateRunParameters): void {
validateThreadId(thread_id);
validateRunId(run_id);
if(options?.body.metadata){
validateMetadata(options.body.metadata);
}
}

function validateCreateRunParameters(options: CreateRunParameters| CreateThreadAndRunParameters): void {
if ('additional_messages' in options.body && options.body.additional_messages && options.body.additional_messages.some(value => !["user", "assistant"].includes(value.role))) {
throw new Error("Role must be either 'user' or 'assistant'");
}
if (options.body.tools) {
validateTools(options.body.tools);
}
if (options.body.temperature && (options.body.temperature < 0 || options.body.temperature > 2)) {
throw new Error("Temperature must be between 0 and 2");
}
if (options.body.tool_choice && typeof options.body.tool_choice !== 'string') {
validateTools([options.body.tool_choice]);
}
if (options.body.truncation_strategy?.type && !["auto", "last_messages"].includes(options.body.truncation_strategy.type)) {
throw new Error("Role must be either 'auto' or 'last_messages'");
}
if (options.body.metadata) {
validateMetadata(options.body.metadata);
}
}

function validateCreateThreadAndRunParameters(options: CreateThreadAndRunParameters): void {
validateCreateRunParameters(options);
if (options.body.thread?.messages && options.body.thread.messages.some(value => !["user", "assistant"].includes(value.role))) {
throw new Error("Role must be either 'user' or 'assistant'");
}
if (options.body.tools) {
validateTools(options.body.tools);
}
if (options.body.tool_resources?.code_interpreter) {
if (options.body.tool_resources.code_interpreter) {
if (options.body.tool_resources.code_interpreter.file_ids && options.body.tool_resources.code_interpreter.file_ids.length > 20) {
throw new Error("A maximum of 20 file IDs are allowed");
}
}
if (options.body.tool_resources.file_search) {
if (options.body.tool_resources.file_search.vector_store_ids && options.body.tool_resources.file_search.vector_store_ids.length > 1) {
throw new Error("Only one vector store ID is allowed");
}
}
if (options.body.tool_resources.azure_ai_search) {
if (options.body.tool_resources.azure_ai_search.indexes && options.body.tool_resources.azure_ai_search.indexes.length > 1) {
throw new Error("Only one index is allowed");
}
}
}
}

0 comments on commit 419844f

Please sign in to comment.