Skip to content

Commit

Permalink
Input validation for threads and streaming (#32057)
Browse files Browse the repository at this point in the history
Co-authored-by: Zachary King <[email protected]>
  • Loading branch information
ZachhK and Zachary King authored Dec 5, 2024
1 parent 419844f commit ba34261
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 20 deletions.
63 changes: 52 additions & 11 deletions sdk/ai/ai-projects/src/agents/inputValidations.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { ToolDefinition, VectorStoreDataSource } from "../generated/src/models.js";

enum Tools {
CodeInterpreter = "code_interpreter",
FileSearch = "file_search",
Function = "function",
BingGrounding = "bing_grounding",
MicrosoftFabric = "microsoft_fabric",
SharepointGrounding = "sharepoint_grounding",
AzureAISearch = "azure_ai_search",
}
import { ToolDefinition, UpdateToolResourcesOptions, VectorStoreDataSource } from "../generated/src/models.js";

export function validateVectorStoreDataType(data_sources: VectorStoreDataSource[]): void {
if (!data_sources.some(value => !["uri_asset", "id_asset"].includes(value.type))) {
Expand Down Expand Up @@ -43,6 +33,16 @@ export function validateOrder(order: string): void {
}
}

enum Tools {
CodeInterpreter = "code_interpreter",
FileSearch = "file_search",
Function = "function",
BingGrounding = "bing_grounding",
MicrosoftFabric = "microsoft_fabric",
SharepointGrounding = "sharepoint_grounding",
AzureAISearch = "azure_ai_search",
}

export function validateTools(value: Array<ToolDefinition>): void {
if (value.some(tool => !Object.values(Tools).includes(tool as unknown as Tools))) {
throw new Error("Tool type must be one of 'code_interpreter', 'file_search', 'function', 'bing_grounding', 'microsoft_fabric', 'sharepoint_grounding', 'azure_ai_search'");
Expand All @@ -61,6 +61,24 @@ export function validateMetadata(metadata: Record<string, string>): void {
}
}

export function validateToolResources(toolResource: UpdateToolResourcesOptions): void {
if (toolResource.code_interpreter) {
if (toolResource.code_interpreter.file_ids && toolResource.code_interpreter.file_ids.length > 20) {
throw new Error("A maximum of 20 file IDs are allowed");
}
}
if (toolResource.file_search) {
if (toolResource.file_search.vector_store_ids && toolResource.file_search.vector_store_ids.length > 1) {
throw new Error("Only one vector store ID is allowed");
}
}
if (toolResource.azure_ai_search) {
if (toolResource.azure_ai_search.indexes && toolResource.azure_ai_search.indexes.length > 1) {
throw new Error("Only one index is allowed");
}
}
}

export function validateVectorStoreId(vectorStoreId: string): void {
if (!vectorStoreId) {
throw new Error("Vector store ID is required");
Expand All @@ -85,3 +103,26 @@ export function validateFileStatusFilter(filter: string): void {
throw new Error("File status filter must be one of 'in_progress', 'completed', 'failed', 'cancelled'");
}
}

enum Messages {
User = "user",
Assistants = "assistant",
}

export function validateMessages(value: string): void {
if (!Object.values(Messages).includes(value as Messages)) {
throw new Error("Role must be either 'user' or 'assistant'");
}
}


enum TruncationStrategy {
Auto = "auto",
LastMessages = "last_messages",
}

export function validateTruncationStrategy(value: string): void {
if (!Object.values(TruncationStrategy).includes(value as TruncationStrategy)) {
throw new Error("Role must be either 'auto' or 'last_messages'");
}
}
14 changes: 7 additions & 7 deletions sdk/ai/ai-projects/src/agents/runs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import { Client, createRestError } from "@azure-rest/core-client";
import { CancelRunParameters, CreateRunParameters, CreateThreadAndRunParameters, GetRunParameters, ListRunsParameters, SubmitToolOutputsToRunParameters, UpdateRunParameters } from "../generated/src/parameters.js";
import { OpenAIPageableListOfThreadRunOutput, ThreadRunOutput } from "../generated/src/outputModels.js";
import { validateLimit, validateMetadata, validateOrder, validateRunId, validateThreadId, validateTools } from "./inputValidations.js";
import { validateLimit, validateMessages, validateMetadata, validateOrder, validateRunId, validateThreadId, validateTools, validateTruncationStrategy } from "./inputValidations.js";

const expectedStatuses = ["200"];

Expand Down Expand Up @@ -150,8 +150,8 @@ function validateUpdateRunParameters(thread_id: string, run_id: string, options?
}

function validateCreateRunParameters(options: CreateRunParameters| CreateThreadAndRunParameters): void {
if ('additional_messages' in options.body && options.body.additional_messages && options.body.additional_messages.some(value => !["user", "assistant"].includes(value.role))) {
throw new Error("Role must be either 'user' or 'assistant'");
if ('additional_messages' in options.body && options.body.additional_messages) {
options.body.additional_messages.forEach(message => validateMessages(message.role));
}
if (options.body.tools) {
validateTools(options.body.tools);
Expand All @@ -162,8 +162,8 @@ function validateCreateRunParameters(options: CreateRunParameters| CreateThreadA
if (options.body.tool_choice && typeof options.body.tool_choice !== 'string') {
validateTools([options.body.tool_choice]);
}
if (options.body.truncation_strategy?.type && !["auto", "last_messages"].includes(options.body.truncation_strategy.type)) {
throw new Error("Role must be either 'auto' or 'last_messages'");
if (options.body.truncation_strategy?.type) {
validateTruncationStrategy(options.body.truncation_strategy.type);
}
if (options.body.metadata) {
validateMetadata(options.body.metadata);
Expand All @@ -172,8 +172,8 @@ function validateCreateRunParameters(options: CreateRunParameters| CreateThreadA

function validateCreateThreadAndRunParameters(options: CreateThreadAndRunParameters): void {
validateCreateRunParameters(options);
if (options.body.thread?.messages && options.body.thread.messages.some(value => !["user", "assistant"].includes(value.role))) {
throw new Error("Role must be either 'user' or 'assistant'");
if (options.body.thread?.messages) {
options.body.thread?.messages.forEach(message => validateMessages(message.role));
}
if (options.body.tools) {
validateTools(options.body.tools);
Expand Down
41 changes: 39 additions & 2 deletions sdk/ai/ai-projects/src/agents/streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ import { AgentEventMessage, AgentEventMessageStream } from "./streamingModels.js
import { createSseStream, EventMessageStream } from "@azure/core-sse";
import { isNodeLike } from "@azure/core-util";
import { IncomingMessage } from "http";
import { validateMessages, validateMetadata, validateRunId, validateThreadId, validateToolResources, validateTools, validateTruncationStrategy } from "./inputValidations.js";

const expectedStatuses = ["200"];



function createAgentStream(stream: EventMessageStream): AgentEventMessageStream {
const asyncIterator = toAsyncIterable(stream);
const asyncDisposable = stream as AsyncDisposable;
Expand Down Expand Up @@ -51,6 +50,8 @@ export async function createRunStreaming(
threadId: string,
options: CreateRunParameters,
): Promise<AgentEventMessageStream> {
validateThreadId(threadId);
validateCreateThreadAndRunBodyParam(options);
options.body.stream = true;

return processStream(context
Expand All @@ -64,6 +65,7 @@ export async function createThreadAndRunStreaming(
context: Client,
options: CreateThreadAndRunBodyParam,
): Promise<AgentEventMessageStream> {
validateCreateThreadAndRunBodyParam(options);
options.body.stream = true;
return processStream(context
.path("/threads/runs")
Expand All @@ -76,8 +78,43 @@ export async function submitToolOutputsToRunStreaming(
runId: string,
options: SubmitToolOutputsToRunParameters,
): Promise<AgentEventMessageStream> {
validateThreadId(threadId);
validateRunId(runId);
options.body.stream = true;

return processStream(context.path("/threads/{threadId}/runs/{runId}/submit_tool_outputs", threadId, runId)
.post(options));
}


function validateCreateThreadAndRunBodyParam(options: CreateRunParameters| CreateThreadAndRunBodyParam): void {
if ('additional_messages' in options.body && options.body.additional_messages) {
options.body.additional_messages.forEach(message => validateMessages(message.role));
}
if ( 'thread' in options.body && options.body.thread?.messages) {
options.body.thread?.messages.forEach(message => validateMessages(message.role));
}
if (options.body.tools) {
validateTools(options.body.tools);
}
if ('tool_resources' in options.body && options?.body.tool_resources) {
validateToolResources(options.body.tool_resources);
}
if (options.body.temperature && (options.body.temperature < 0 || options.body.temperature > 2)) {
throw new Error("Temperature must be between 0 and 2");
}
if (options.body.tool_choice && typeof options.body.tool_choice !== 'string') {
validateTools([options.body.tool_choice]);
}
if (options.body.truncation_strategy?.type) {
validateTruncationStrategy(options.body.truncation_strategy.type);
}
if (options.body.response_format){
if (!["json", "text"].includes(options.body.response_format as string)) {
throw new Error("Response format must be either 'json' or 'text'");
}
}
if (options?.body.metadata){
validateMetadata(options.body.metadata);
}
}
28 changes: 28 additions & 0 deletions sdk/ai/ai-projects/src/agents/threads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { Client, createRestError } from "@azure-rest/core-client";
import { CreateThreadParameters, DeleteThreadParameters, GetThreadParameters, UpdateThreadParameters } from "../generated/src/parameters.js";
import { AgentThreadOutput, ThreadDeletionStatusOutput } from "../generated/src/outputModels.js";
import { validateMessages, validateMetadata, validateThreadId, validateToolResources } from "./inputValidations.js";

const expectedStatuses = ["200"];

Expand All @@ -12,6 +13,7 @@ export async function createThread(
context: Client,
options?: CreateThreadParameters,
): Promise<AgentThreadOutput> {
validateCreateThreadParameters(options);
const result = await context.path("/threads").post(options);
if (!expectedStatuses.includes(result.status)) {
throw createRestError(result);
Expand All @@ -25,6 +27,7 @@ export async function getThread(
threadId: string,
options?: GetThreadParameters,
): Promise<AgentThreadOutput> {
validateThreadId(threadId);
const result = await context
.path("/threads/{threadId}", threadId)
.get(options);
Expand All @@ -40,6 +43,7 @@ export async function updateThread(
threadId: string,
options?: UpdateThreadParameters,
): Promise<AgentThreadOutput> {
validateUpdateThreadParameters(threadId, options);
const result = await context
.path("/threads/{threadId}", threadId)
.post(options);
Expand All @@ -55,6 +59,7 @@ export async function deleteThread(
threadId: string,
options?: DeleteThreadParameters,
): Promise<ThreadDeletionStatusOutput> {
validateThreadId(threadId);
const result = await context
.path("/threads/{threadId}", threadId)
.delete(options);
Expand All @@ -63,3 +68,26 @@ export async function deleteThread(
}
return result.body;
}


function validateCreateThreadParameters(options?: CreateThreadParameters): void {
if (options?.body.messages) {
options.body.messages.forEach(message => validateMessages(message.role));
}
if (options?.body.tool_resources) {
validateToolResources(options.body.tool_resources);
}
if (options?.body.metadata){
validateMetadata(options.body.metadata);
}
}

function validateUpdateThreadParameters(threadId: string, options?: UpdateThreadParameters): void {
validateThreadId(threadId);
if (options?.body.tool_resources) {
validateToolResources(options.body.tool_resources);
}
if (options?.body.metadata){
validateMetadata(options.body.metadata);
}
}

0 comments on commit ba34261

Please sign in to comment.