Skip to content

Proposal to use method options cloning #354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ protected AzureChatClient()
/// <inheritdoc/>
public override Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
PostfixSwapMaxTokens(ref options);
return base.CompleteChatAsync(messages, options, cancellationToken);
}

/// <inheritdoc/>
public override ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
PostfixSwapMaxTokens(ref options);
return base.CompleteChat(messages, options, cancellationToken);
}

Expand All @@ -64,16 +62,12 @@ public override CollectionResult<StreamingChatCompletionUpdate> CompleteChatStre
/// <inheritdoc/>
public override AsyncCollectionResult<StreamingChatCompletionUpdate> CompleteChatStreamingAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
PostfixClearStreamOptions(messages, ref options);
PostfixSwapMaxTokens(ref options);
return base.CompleteChatStreamingAsync(messages, options, cancellationToken);
}

/// <inheritdoc/>
public override CollectionResult<StreamingChatCompletionUpdate> CompleteChatStreaming(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
PostfixClearStreamOptions(messages, ref options);
PostfixSwapMaxTokens(ref options);
return base.CompleteChatStreaming(messages, options, cancellationToken);
}

Expand Down Expand Up @@ -162,4 +156,12 @@ private static void PostfixSwapMaxTokens(ref ChatCompletionOptions options)
}
}
}

/// <summary>
/// Produces the effective per-call options for an Azure request: obtains the cloned,
/// message-populated options from the base implementation, then applies the Azure-specific
/// postfix adjustments for stream options and max token serialization.
/// </summary>
internal override ChatCompletionOptions CreatePerCallOptions(ChatCompletionOptions userOptions, IEnumerable<ChatMessage> messages, bool stream = false)
{
    // Start from the base client's clone so the caller's options instance is never mutated.
    ChatCompletionOptions perCallOptions = base.CreatePerCallOptions(userOptions, messages, stream);

    // Azure-specific fix-ups operate on the clone only.
    PostfixClearStreamOptions(messages, ref perCallOptions);
    PostfixSwapMaxTokens(ref perCallOptions);

    return perCallOptions;
}
}
205 changes: 127 additions & 78 deletions .dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI.Chat;
using Azure.AI.OpenAI.Tests.Utils.Config;
Expand All @@ -32,35 +33,22 @@ public ChatTests(bool isAsync) : base(isAsync)
[Category("Smoke")]
public async Task DefaultUserAgentStringWorks()
{
using MockHttpMessageHandler pipeline = new(MockHttpMessageHandler.ReturnEmptyJson);

Uri endpoint = new Uri("https://www.bing.com/");
string apiKey = "not-a-real-one";
string model = "ignore";

AzureOpenAIClient topLevel = new(
endpoint,
new ApiKeyCredential(apiKey),
new AzureOpenAIClientOptions()
{
Transport = pipeline.Transport
});

ChatClient client = WrapClient(topLevel.GetChatClient(model));
using MockHttpMessageHandler messageHandler = new(MockHttpMessageHandler.ReturnEmptyJson);
ChatClient client = GetMockChatClient(messageHandler);

await client.CompleteChatAsync([new UserChatMessage("Hello")]);

Assert.That(pipeline.Requests, Is.Not.Empty);
Assert.That(messageHandler.Requests, Is.Not.Empty);

var request = pipeline.Requests[0];
var request = messageHandler.Requests[0];
Assert.That(request.Method, Is.EqualTo(HttpMethod.Post));
Assert.That(request.Uri?.GetLeftPart(UriPartial.Authority), Is.EqualTo(endpoint.GetLeftPart(UriPartial.Authority)));
Assert.That(request.Headers.GetValueOrDefault("api-key")?.FirstOrDefault(), Is.EqualTo(apiKey));
Assert.That(request.Uri?.GetLeftPart(UriPartial.Authority), Is.EqualTo(s_mockEndpoint.GetLeftPart(UriPartial.Authority)));
Assert.That(request.Headers.GetValueOrDefault("api-key")?.FirstOrDefault(), Is.EqualTo(s_mockApiKeyValue));
Assert.That(request.Headers.GetValueOrDefault("User-Agent")?.FirstOrDefault(), Does.Contain("azsdk-net-AI.OpenAI/"));
Assert.That(request.Content, Is.Not.Null);
var jsonString = request.Content.ToString();
Assert.That(jsonString, Is.Not.Null.Or.Empty);
Assert.That(jsonString, Does.Contain("\"messages\"").And.Contain("\"model\"").And.Contain(model));
Assert.That(jsonString, Does.Contain("\"messages\"").And.Contain("\"model\"").And.Contain(s_mockModelValue));
}

[Test]
Expand Down Expand Up @@ -146,33 +134,31 @@ public void DataSourceSerializationWorks()
[Category("Smoke")]
public async Task MaxTokensSerializationConfigurationWorks()
{
using MockHttpMessageHandler pipeline = new(MockHttpMessageHandler.ReturnEmptyJson);

Uri endpoint = new Uri("https://www.bing.com/");
string apiKey = "not-a-real-one";
string model = "ignore";

AzureOpenAIClient topLevel = new(
endpoint,
new ApiKeyCredential(apiKey),
new AzureOpenAIClientOptions()
{
Transport = pipeline.Transport
});
using MockHttpMessageHandler messageHandler = new(MockHttpMessageHandler.ReturnEmptyJson);
ChatClient client = GetMockChatClient(messageHandler);

ChatClient client = topLevel.GetChatClient(model);
AutoResetEvent newRequestEvent = new(false);
string? latestSerializedRequest = null;
messageHandler.OnRequest += (sender, request) =>
{
latestSerializedRequest = request.Content.ToString();
newRequestEvent.Set();
};

ChatCompletionOptions options = new();
bool GetSerializedOptionsContains(string value)

async Task<string> GetNextSerializedRequest()
{
BinaryData serialized = ModelReaderWriter.Write(options);
return serialized.ToString().Contains(value);
_ = await client.CompleteChatAsync(["Just mocking, no call here"], options);
newRequestEvent.WaitOne();
return latestSerializedRequest ?? string.Empty;
}

async Task AssertExpectedSerializationAsync(bool hasOldMaxTokens, bool hasNewMaxCompletionTokens)
{
_ = await client.CompleteChatAsync(["Just mocking, no call here"], options);
Assert.That(GetSerializedOptionsContains("max_tokens"), Is.EqualTo(hasOldMaxTokens));
Assert.That(GetSerializedOptionsContains("max_completion_tokens"), Is.EqualTo(hasNewMaxCompletionTokens));
string serializedRequest = await GetNextSerializedRequest();
Assert.That(serializedRequest.Contains("max_tokens"), Is.EqualTo(hasOldMaxTokens));
Assert.That(serializedRequest.Contains("max_completion_tokens"), Is.EqualTo(hasNewMaxCompletionTokens));
}

await AssertExpectedSerializationAsync(false, false);
Expand Down Expand Up @@ -324,7 +310,7 @@ public async Task RateLimitedRetryWorks(string headerName, string headerValue, d
Assert.That(observed429Delay!.Value.TotalMilliseconds, Is.LessThan(3 * expectedDelayMilliseconds + 2 * observed200Delay!.Value.TotalMilliseconds));
}

#endregion
#endregion

#region Regular chat completions tests

Expand Down Expand Up @@ -542,43 +528,6 @@ public async Task UserSecurityContextWorks()
Assert.That(completion, Is.Not.Null);
}

[RecordedTest]
[TestCase("chat", false)]
[TestCase("chat_o1", true)]
[TestCase("chat_o3-mini", true)]
public async Task MaxOutputTokensWorksAcrossModels(string testConfigName, bool useNewProperty)
{
IConfiguration testConfig = TestConfig.GetConfig(testConfigName)!;
ChatClient client = GetTestClient(testConfig);

ChatCompletionOptions options = new()
{
MaxOutputTokenCount = 16,
};

if (useNewProperty)
{
options.SetNewMaxCompletionTokensPropertyEnabled();
}

ChatCompletion completion = await client.CompleteChatAsync(
["Hello, world! Please write a funny haiku to greet me."],
options);
Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Length));

string serializedOptionsAfterUse = ModelReaderWriter.Write(options).ToString();

if (useNewProperty)
{
Assert.That(serializedOptionsAfterUse, Does.Contain("max_completion_tokens"));
Assert.That(serializedOptionsAfterUse, Does.Not.Contain("max_tokens"));
}
else
{
Assert.That(serializedOptionsAfterUse, Does.Not.Contain("max_completion_tokens"));
Assert.That(serializedOptionsAfterUse, Does.Contain("max_tokens"));
}
}
#endregion

#region Streaming chat completion tests
Expand Down Expand Up @@ -781,6 +730,89 @@ public async Task ChatWithO1Works()
Assert.That(completion, Is.Not.Null);
}

[Test]
[Category("Smoke")]
public async Task StreamOptionsResetAppropriately()
{
    using MockHttpMessageHandler messageHandler = new(MockHttpMessageHandler.ReturnEmptyJson);
    ChatClient client = GetMockChatClient(messageHandler);

    // Signaled each time the mock transport observes a request so the test can read the most
    // recently serialized request body. AutoResetEvent is IDisposable, so dispose it with the test.
    using AutoResetEvent newRequestEvent = new(false);
    string? latestSerializedRequest = null;
    messageHandler.OnRequest += (sender, request) =>
    {
        latestSerializedRequest = request.Content.ToString();
        newRequestEvent.Set();
    };

    // Bounded wait: if no request ever arrives, fail the test with a clear message instead of
    // hanging the whole test run on an untimed WaitOne().
    void WaitForNextRequest()
        => Assert.That(newRequestEvent.WaitOne(TimeSpan.FromSeconds(10)), Is.True, "Timed out waiting for the mock transport to observe a request.");

    ChatCompletionOptions options = new();

    // The caller-provided options must never be mutated by a call; capture the baseline serialization.
    string serializedOriginalOptions = ModelReaderWriter.Write(options).ToString();
    void AssertSerializedOptionsUnchanged() => Assert.That(ModelReaderWriter.Write(options).ToString(), Is.EqualTo(serializedOriginalOptions));

    // When not streaming, stream_options should not be present in the request

    _ = await client.CompleteChatAsync(["Hello, mock"], options);
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Not.Contain("stream_options"));
    AssertSerializedOptionsUnchanged();

    // When streaming, stream_options should now be present

    await foreach (StreamingChatCompletionUpdate update in client.CompleteChatStreamingAsync(["Hello, mock"], options))
    { }
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Contain("stream_options"));
    AssertSerializedOptionsUnchanged();

    // Going back to non-streaming, stream_options should again not be present

    _ = await client.CompleteChatAsync(["Hello, mock"], options);
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Not.Contain("stream_options"));
    AssertSerializedOptionsUnchanged();

    // When data_sources are provided, stream_options should specially be omitted even when streaming

    AzureSearchChatDataSource source = new()
    {
        Endpoint = new Uri("https://some-search-resource.azure.com"),
        Authentication = DataSourceAuthentication.FromApiKey("test-api-key"),
        IndexName = "index-name-here",
        FieldMappings = new()
        {
            ContentFieldNames = { "hello" },
            TitleFieldName = "hi",
        },
        AllowPartialResults = true,
        QueryType = DataSourceQueryType.Simple,
        OutputContexts = DataSourceOutputContexts.AllRetrievedDocuments | DataSourceOutputContexts.Citations,
        VectorizationSource = DataSourceVectorizer.FromEndpoint(
            new Uri("https://my-embedding.com"),
            DataSourceAuthentication.FromApiKey("embedding-api-key")),
    };
    options.AddDataSource(source);
    // Re-baseline: adding the data source intentionally changed the options serialization.
    serializedOriginalOptions = ModelReaderWriter.Write(options).ToString();

    await foreach (StreamingChatCompletionUpdate update in client.CompleteChatStreamingAsync(["Hello, mock"], options))
    { }
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Not.Contain("stream_options"));
    AssertSerializedOptionsUnchanged();

    // And the non-presence should of course also be true for non-streaming

    _ = await client.CompleteChatAsync(["Hello, mock"], options);
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Not.Contain("stream_options"));
    AssertSerializedOptionsUnchanged();

    // Finally, with no/default options, streaming should have stream_options
    await foreach (StreamingChatCompletionUpdate update in client.CompleteChatStreamingAsync(["Hello, mock"]))
    { }
    WaitForNextRequest();
    Assert.That(latestSerializedRequest, Does.Contain("stream_options"));
}

#if NET
[RecordedTest]
public async Task PredictedOutputsWork()
Expand Down Expand Up @@ -904,4 +936,21 @@ private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder

#endregion
}

/// <summary>
/// Creates a <see cref="ChatClient"/> routed through the supplied mock transport so tests can
/// inspect outgoing requests without contacting any real service.
/// </summary>
private ChatClient GetMockChatClient(MockHttpMessageHandler mockHttpMessageHandler)
{
    // Substitute the mock transport for the default network pipeline.
    AzureOpenAIClientOptions clientOptions = new()
    {
        Transport = mockHttpMessageHandler.Transport
    };

    AzureOpenAIClient azureClient = new(
        s_mockEndpoint,
        new ApiKeyCredential(s_mockApiKeyValue),
        clientOptions);

    return WrapClient(azureClient.GetChatClient(s_mockModelValue));
}

// Placeholder connection details shared by the mock-transport smoke tests; nothing is ever contacted.
private static readonly Uri s_mockEndpoint = new Uri("https://www.bing.com");
private static readonly string s_mockApiKeyValue = "not-a-real-key";
private static readonly string s_mockModelValue = "not-a-real-model";
}
31 changes: 13 additions & 18 deletions .dotnet/src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnume
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);
options = CreatePerCallOptions(options, messages);
using OpenTelemetryScope scope = _telemetry.StartChatScope(options);

try
Expand Down Expand Up @@ -133,8 +132,7 @@ public virtual ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);
options = CreatePerCallOptions(options, messages);
using OpenTelemetryScope scope = _telemetry.StartChatScope(options);

try
Expand Down Expand Up @@ -184,8 +182,8 @@ public virtual AsyncCollectionResult<StreamingChatCompletionUpdate> CompleteChat
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);
options = CreatePerCallOptions(options, messages, stream: true);
using OpenTelemetryScope scope = _telemetry.StartChatScope(options);

using BinaryContent content = options;

Expand All @@ -211,8 +209,8 @@ public virtual CollectionResult<StreamingChatCompletionUpdate> CompleteChatStrea
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);
options = CreatePerCallOptions(options, messages, stream: true);
using OpenTelemetryScope scope = _telemetry.StartChatScope(options);

using BinaryContent content = options;
ClientResult sendRequest() => CompleteChat(content, cancellationToken.ToRequestOptions(streaming: true));
Expand Down Expand Up @@ -247,20 +245,17 @@ public virtual AsyncCollectionResult<StreamingChatCompletionUpdate> CompleteChat
public virtual CollectionResult<StreamingChatCompletionUpdate> CompleteChatStreaming(params ChatMessage[] messages)
=> CompleteChatStreaming(messages, default(ChatCompletionOptions));

private void CreateChatCompletionOptions(IEnumerable<ChatMessage> messages, ref ChatCompletionOptions options, bool stream = false)
internal virtual ChatCompletionOptions CreatePerCallOptions(ChatCompletionOptions userOptions, IEnumerable<ChatMessage> messages, bool stream = false)
{
options.Messages = messages.ToList();
options.Model = _model;
ChatCompletionOptions copiedOptions = userOptions?.GetClone() ?? new();
copiedOptions.Messages = messages.ToList();
copiedOptions.Model = _model;
if (stream)
{
options.Stream = true;
options.StreamOptions = s_includeUsageStreamOptions;
}
else
{
options.Stream = null;
options.StreamOptions = null;
copiedOptions.Stream = true;
copiedOptions.StreamOptions = s_includeUsageStreamOptions;
}
return copiedOptions;
}

private static readonly InternalChatCompletionStreamOptions s_includeUsageStreamOptions
Expand Down
Loading