Skip to content

Commit aafa6a2

Browse files
Sanitize structured output schema type name to satisfy restrictions (#5504)
2 parents 221449d + ac55349 commit aafa6a2

File tree

4 files changed

+61
-28
lines changed

4 files changed

+61
-28
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using System.Text.Json.Nodes;
1515
using System.Text.Json.Schema;
1616
using System.Text.Json.Serialization;
17+
using System.Text.RegularExpressions;
1718
using Microsoft.Shared.Diagnostics;
1819

1920
using FunctionParameterKey = (System.Type? Type, string ParameterName, string? Description, bool HasDefaultValue, object? DefaultValue);
@@ -375,4 +376,20 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
375376
[JsonSerializable(typeof(JsonElement))]
376377
[JsonSerializable(typeof(JsonDocument))]
377378
private sealed partial class FunctionCallHelperContext : JsonSerializerContext;
379+
380+
/// <summary>
381+
/// Replaces, with underscores, characters that are valid in a metadata name but shouldn't be used in a method or schema name.
382+
/// This is primarily intended to remove characters emitted by compiler-generated method name mangling.
383+
/// </summary>
384+
public static string SanitizeMetadataName(string metadataName) =>
385+
InvalidNameCharsRegex().Replace(metadataName, "_");
386+
387+
/// <summary>Regex that flags any character other than ASCII digits or letters or the underscore.</summary>
388+
#if NET
389+
[GeneratedRegex("[^0-9A-Za-z_]")]
390+
private static partial Regex InvalidNameCharsRegex();
391+
#else
392+
private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex;
393+
private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
394+
#endif
378395
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using System.Threading;
1515
using System.Threading.Tasks;
1616
using Microsoft.Shared.Diagnostics;
17+
using static Microsoft.Extensions.AI.FunctionCallHelpers;
1718

1819
namespace Microsoft.Extensions.AI;
1920

@@ -167,7 +168,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
167168
// the LLM backend is meant to do whatever's needed to explain the schema to the LLM.
168169
options.ResponseFormat = ChatResponseFormat.ForJsonSchema(
169170
schema,
170-
schemaName: typeof(T).Name,
171+
schemaName: SanitizeMetadataName(typeof(T).Name),
171172
schemaDescription: typeof(T).GetCustomAttribute<DescriptionAttribute>()?.Description);
172173
}
173174
else

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,16 @@
1212
using System.Text.Json;
1313
using System.Text.Json.Nodes;
1414
using System.Text.Json.Serialization.Metadata;
15-
using System.Text.RegularExpressions;
1615
using System.Threading;
1716
using System.Threading.Tasks;
1817
using Microsoft.Shared.Collections;
1918
using Microsoft.Shared.Diagnostics;
19+
using static Microsoft.Extensions.AI.FunctionCallHelpers;
2020

2121
namespace Microsoft.Extensions.AI;
2222

2323
/// <summary>Provides factory methods for creating commonly-used implementations of <see cref="AIFunction"/>.</summary>
24-
public static
25-
#if NET
26-
partial
27-
#endif
28-
class AIFunctionFactory
24+
public static class AIFunctionFactory
2925
{
3026
internal const string UsesReflectionJsonSerializerMessage =
3127
"This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications.";
@@ -107,11 +103,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
107103
return new ReflectionAIFunction(method, target, options);
108104
}
109105

110-
private sealed
111-
#if NET
112-
partial
113-
#endif
114-
class ReflectionAIFunction : AIFunction
106+
private sealed class ReflectionAIFunction : AIFunction
115107
{
116108
private readonly MethodInfo _method;
117109
private readonly object? _target;
@@ -474,21 +466,5 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
474466
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
475467
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
476468
}
477-
478-
/// <summary>
479-
/// Remove characters from method name that are valid in metadata but shouldn't be used in a method name.
480-
/// This is primarily intended to remove characters emitted by for compiler-generated method name mangling.
481-
/// </summary>
482-
private static string SanitizeMetadataName(string methodName) =>
483-
InvalidNameCharsRegex().Replace(methodName, "_");
484-
485-
/// <summary>Regex that flags any character other than ASCII digits or letters or the underscore.</summary>
486-
#if NET
487-
[GeneratedRegex("[^0-9A-Za-z_]")]
488-
private static partial Regex InvalidNameCharsRegex();
489-
#else
490-
private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex;
491-
private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
492-
#endif
493469
}
494470
}

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,40 @@ public async Task CanUseNativeStructuredOutput()
172172
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
173173
}
174174

175+
[Fact]
176+
public async Task CanUseNativeStructuredOutputWithSanitizedTypeName()
177+
{
178+
var expectedResult = new Data<Animal> { Value = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger } };
179+
var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]);
180+
181+
using var client = new TestChatClient
182+
{
183+
CompleteAsyncCallback = (messages, options, cancellationToken) =>
184+
{
185+
var responseFormat = Assert.IsType<ChatResponseFormatJson>(options!.ResponseFormat);
186+
187+
Assert.Matches("Data_1", responseFormat.SchemaName);
188+
189+
return Task.FromResult(expectedCompletion);
190+
},
191+
};
192+
193+
var chatHistory = new List<ChatMessage> { new(ChatRole.User, "Hello") };
194+
var response = await client.CompleteAsync<Data<Animal>>(chatHistory, useNativeJsonSchema: true);
195+
196+
// The completion contains the deserialized result and other completion properties
197+
Assert.Equal(1, response.Result!.Value!.Id);
198+
Assert.Equal("Tigger", response.Result.Value.FullName);
199+
Assert.Equal(Species.Tiger, response.Result.Value.Species);
200+
201+
// TryGetResult returns the same value
202+
Assert.True(response.TryGetResult(out var tryGetResultOutput));
203+
Assert.Same(response.Result, tryGetResultOutput);
204+
205+
// History remains unmutated
206+
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
207+
}
208+
175209
[Fact]
176210
public async Task CanSpecifyCustomJsonSerializationOptions()
177211
{
@@ -247,6 +281,11 @@ private class Animal
247281
public Species Species { get; set; }
248282
}
249283

284+
private class Data<T>
285+
{
286+
public T? Value { get; set; }
287+
}
288+
250289
private enum Species
251290
{
252291
Bear,

0 commit comments

Comments
 (0)