Skip to content

Commit

Permalink
Sanitize structured output schema type name to satisfy restrictions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
eiriktsarpalis authored Oct 10, 2024
2 parents 221449d + ac55349 commit aafa6a2
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
using Microsoft.Shared.Diagnostics;

using FunctionParameterKey = (System.Type? Type, string ParameterName, string? Description, bool HasDefaultValue, object? DefaultValue);
Expand Down Expand Up @@ -375,4 +376,20 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(JsonDocument))]
private sealed partial class FunctionCallHelperContext : JsonSerializerContext;

/// <summary>
/// Produces a sanitized copy of a metadata name in which every character that is not an ASCII letter,
/// an ASCII digit, or an underscore is replaced with an underscore.
/// This is primarily intended to neutralize characters emitted by compiler-generated name mangling.
/// </summary>
/// <param name="metadataName">The metadata name to sanitize.</param>
/// <returns>The input with each disallowed character replaced by <c>_</c>.</returns>
public static string SanitizeMetadataName(string metadataName)
{
    Regex invalidNameChars = InvalidNameCharsRegex();
    return invalidNameChars.Replace(metadataName, "_");
}

#if NET
/// <summary>Gets a regex matching any single character that is not an ASCII digit, an ASCII letter, or the underscore.</summary>
[GeneratedRegex("[^0-9A-Za-z_]")]
private static partial Regex InvalidNameCharsRegex();
#else
// Single shared instance, constructed once and handed out by InvalidNameCharsRegex().
private static readonly Regex _invalidNameCharsRegex = new Regex("[^0-9A-Za-z_]", RegexOptions.Compiled);

/// <summary>Gets a regex matching any single character that is not an ASCII digit, an ASCII letter, or the underscore.</summary>
private static Regex InvalidNameCharsRegex()
{
    return _invalidNameCharsRegex;
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
using static Microsoft.Extensions.AI.FunctionCallHelpers;

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -167,7 +168,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
// the LLM backend is meant to do whatever's needed to explain the schema to the LLM.
options.ResponseFormat = ChatResponseFormat.ForJsonSchema(
schema,
schemaName: typeof(T).Name,
schemaName: SanitizeMetadataName(typeof(T).Name),
schemaDescription: typeof(T).GetCustomAttribute<DescriptionAttribute>()?.Description);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization.Metadata;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
using static Microsoft.Extensions.AI.FunctionCallHelpers;

namespace Microsoft.Extensions.AI;

/// <summary>Provides factory methods for creating commonly-used implementations of <see cref="AIFunction"/>.</summary>
public static
#if NET
partial
#endif
class AIFunctionFactory
public static class AIFunctionFactory
{
internal const string UsesReflectionJsonSerializerMessage =
"This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications.";
Expand Down Expand Up @@ -107,11 +103,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
return new ReflectionAIFunction(method, target, options);
}

private sealed
#if NET
partial
#endif
class ReflectionAIFunction : AIFunction
private sealed class ReflectionAIFunction : AIFunction
{
private readonly MethodInfo _method;
private readonly object? _target;
Expand Down Expand Up @@ -474,21 +466,5 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
}

/// <summary>
/// Replaces characters in a method name that are valid in metadata but shouldn't be used in a method name.
/// Each such character is replaced with an underscore.
/// This is primarily intended to handle characters emitted by compiler-generated method name mangling.
/// </summary>
/// <param name="methodName">The metadata method name to sanitize.</param>
/// <returns>The name with every character matched by <c>InvalidNameCharsRegex</c> replaced by <c>_</c>.</returns>
private static string SanitizeMetadataName(string methodName) =>
InvalidNameCharsRegex().Replace(methodName, "_");

/// <summary>Regex that matches any single character other than ASCII digits, ASCII letters, or the underscore.</summary>
#if NET
// When NET is defined, the implementation of this partial method is supplied via [GeneratedRegex].
[GeneratedRegex("[^0-9A-Za-z_]")]
private static partial Regex InvalidNameCharsRegex();
#else
// Otherwise, return one cached instance built with RegexOptions.Compiled (same pattern as above).
private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex;
private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ public async Task CanUseNativeStructuredOutput()
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

// Verifies that, when native JSON schema output is requested for a generic result type,
// the schema name passed to the client is the sanitized type name ("Data_1" for Data<Animal>)
// and that the typed completion still deserializes correctly.
[Fact]
public async Task CanUseNativeStructuredOutputWithSanitizedTypeName()
{
    var expectedResult = new Data<Animal> { Value = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger } };
    var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]);

    using var client = new TestChatClient
    {
        CompleteAsyncCallback = (messages, options, cancellationToken) =>
        {
            var responseFormat = Assert.IsType<ChatResponseFormatJson>(options!.ResponseFormat);

            // Assert.Matches would treat "Data_1" as an unanchored regex and accept any name that
            // merely contains it; use an equality check to pin the exact sanitized schema name.
            Assert.Equal("Data_1", responseFormat.SchemaName);
            return Task.FromResult(expectedCompletion);
        },
    };

    var chatHistory = new List<ChatMessage> { new(ChatRole.User, "Hello") };
    var response = await client.CompleteAsync<Data<Animal>>(chatHistory, useNativeJsonSchema: true);

    // The completion contains the deserialized result and other completion properties
    Assert.Equal(1, response.Result!.Value!.Id);
    Assert.Equal("Tigger", response.Result.Value.FullName);
    Assert.Equal(Species.Tiger, response.Result.Value.Species);

    // TryGetResult returns the same value
    Assert.True(response.TryGetResult(out var tryGetResultOutput));
    Assert.Same(response.Result, tryGetResultOutput);

    // History remains unmutated
    Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanSpecifyCustomJsonSerializationOptions()
{
Expand Down Expand Up @@ -247,6 +281,11 @@ private class Animal
public Species Species { get; set; }
}

/// <summary>Generic wrapper exposing a single, possibly-null <see cref="Value"/>.</summary>
/// <typeparam name="T">The type of the wrapped value.</typeparam>
private class Data<T>
{
    /// <summary>Gets or sets the wrapped value.</summary>
    public T? Value
    {
        get;
        set;
    }
}

private enum Species
{
Bear,
Expand Down

0 comments on commit aafa6a2

Please sign in to comment.