Skip to content

Commit

Permalink
[OpenAI] Fix Function Schemas Transformed into Snake Case (#28690)
Browse files Browse the repository at this point in the history
  • Loading branch information
minhanh-phan authored Feb 28, 2024
1 parent db56ecf commit 09f082c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sdk/openai/openai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "js",
"TagPrefix": "js/openai/openai",
"Tag": "js/openai/openai_987b2476c8"
"Tag": "js/openai/openai_6b73e03317"
}
6 changes: 4 additions & 2 deletions sdk/openai/openai/src/api/client/openAIClient/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -750,13 +750,14 @@ function _getChatCompletionsWithAzureExtensionsSend(
| GetChatCompletionsWithAzureExtensions200Response
| GetChatCompletionsWithAzureExtensionsDefaultResponse
> {
const { functions, functionCall, messages, dataSources, ...rest } = body;
const { functions, functionCall, messages, dataSources, tools, ...rest } = body;
return context
.path("/deployments/{deploymentId}/extensions/chat/completions", deploymentName)
.post({
...operationOptionsToRequestParameters(options),
body: {
...snakeCaseKeys(rest),
tools,
dataSources: dataSources?.map(
({ type, ...opts }) => ({ type, parameters: opts }) as AzureChatExtensionConfiguration,
),
Expand Down Expand Up @@ -792,11 +793,12 @@ function _getChatCompletionsSend(
body: GeneratedChatCompletionsOptions,
options: ClientOpenAIClientGetChatCompletionsOptions = { requestOptions: {} },
): StreamableMethod<GetChatCompletions200Response | GetChatCompletionsDefaultResponse> {
const { functions, functionCall, messages, ...rest } = body;
const { functions, functionCall, messages, tools, ...rest } = body;
return context.path("/deployments/{deploymentId}/chat/completions", deploymentName).post({
...operationOptionsToRequestParameters(options),
body: {
...snakeCaseKeys(rest),
tools,
functions,
function_call: functionCall,
messages: messages.map(serializeChatRequestMessage),
Expand Down
45 changes: 45 additions & 0 deletions sdk/openai/openai/test/public/completions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,51 @@ describe("OpenAI", function () {
);
});

it("ensure schema name is not transformed with snake case", async function () {
const getAssetInfo = {
name: "getAssetInfo",
description: "Returns information about an asset",
parameters: {
type: "object",
properties: {
assetName: {
type: "string",
description: "The asset name. This is a required parameter.",
},
},
required: ["assetName"],
},
};
updateWithSucceeded(
await withDeployments(
getSucceeded(
authMethod,
deployments,
models,
chatCompletionDeployments,
chatCompletionModels,
),
(deploymentName) =>
client.getChatCompletions(
deploymentName,
[{ role: "user", content: "Give me information about Asset No1" }],
{
tools: [{ type: "function", function: getAssetInfo }],
},
),
(res) => {
assertChatCompletions(res, { functions: true });
assert.isDefined(res.choices[0].message?.toolCalls);
const argument = res.choices[0].message?.toolCalls[0].function.arguments;
assert.isTrue(argument?.includes("assetName"));
},
),
chatCompletionDeployments,
chatCompletionModels,
authMethod,
);
});

it("respects json_object responseFormat", async function () {
if (authMethod !== "OpenAIKey") {
this.skip();
Expand Down
13 changes: 11 additions & 2 deletions sdk/openai/openai/test/public/utils/asserts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ function assertNonEmptyArray<T>(val: T[], validate: (x: T) => void): void {
}
}

function assertArray<T>(val: T[], validate: (x: T) => void): void {
assert.isArray(val);
for (const x of val) {
validate(x);
}
}

async function assertAsyncIterable<T>(
val: AsyncIterable<T>,
validate: (x: T) => void,
Expand Down Expand Up @@ -121,9 +128,11 @@ function assertContentFilterResultDetailsForPrompt(cfr: ContentFilterResultDetai
ifDefined(cfr.selfHarm, assertContentFilterResult);
ifDefined(cfr.sexual, assertContentFilterResult);
ifDefined(cfr.violence, assertContentFilterResult);
ifDefined(cfr.profanity, assertContentFilterResult);
ifDefined(cfr.profanity, assertContentFilterDetectionResult);
ifDefined(cfr.jailbreak, assertContentFilterDetectionResult);
ifDefined(cfr.customBlocklists, assertContentFilterBlocklistIdResult);
ifDefined(cfr.customBlocklists, (arr) =>
assertArray(arr, assertContentFilterBlocklistIdResult),
);
}
}

Expand Down

0 comments on commit 09f082c

Please sign in to comment.