diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/OnnxGenAIChatCompletionsStreamingClass.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/OnnxGenAIChatCompletionsStreamingClass.cs index a7f3051f..4674e792 100644 --- a/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/OnnxGenAIChatCompletionsStreamingClass.cs +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/OnnxGenAIChatCompletionsStreamingClass.cs @@ -3,6 +3,12 @@ public class ContentMessage { + public ContentMessage() + { + Role = string.Empty; + Content = string.Empty; + } + public string Role { get; set; } public string Content { get; set; } } @@ -33,8 +39,8 @@ public string GetChatCompletionStreaming(string userPrompt, Action? call var responseContent = string.Empty; using var tokens = _tokenizer.Encode(string.Join("\n", _messages - .Select(m => $"<|{m.Role}|>{m.Content}<|end|>")) - + "<|assistant|>"); + .Select(m => $"<|{m.Role}|>\n{m.Content}\n<|end|>")) + + "<|assistant|>\n"); using var generatorParams = new GeneratorParams(_model); generatorParams.SetSearchOption("max_length", 2048); diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/_.json b/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/_.json index 4db76896..f426ef55 100644 --- a/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/_.json +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-cs/_.json @@ -3,5 +3,5 @@ "_ShortName": "phi3-onnx-chat-streaming", "_Language": "C#", - "ONNX_GENAI_MODEL_PLATFORM": "CPU" + "ONNX_GENAI_MODEL_PLATFORM": "DIRECTML" } \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/FunctionFactory.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/FunctionFactory.cs new file mode 100644 index 00000000..51a669b9 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/FunctionFactory.cs @@ -0,0 +1,434 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +using System.Reflection; +using System.Collections; +using System.Text; +using System.Text.Json; +using System.Collections.Generic; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class FunctionFactory +{ + public FunctionFactory() + { + } + + public FunctionFactory(Assembly assembly) + { + AddFunctions(assembly); + } + + public FunctionFactory(Type type1, params Type[] types) + { + AddFunctions(type1, types); + } + + public FunctionFactory(IEnumerable types) + { + AddFunctions(types); + } + + public FunctionFactory(Type type) + { + AddFunctions(type); + } + + public void AddFunctions(Assembly assembly) + { + AddFunctions(assembly.GetTypes()); + } + + public void AddFunctions(Type type1, params Type[] types) + { + AddFunctions(new List { type1 }); + AddFunctions(types); + } + + public void AddFunctions(IEnumerable types) + { + foreach (var type in types) + { + AddFunctions(type); + } + } + + public void AddFunctions(Type type) + { + var methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public); + foreach (var method in methods) + { + AddFunction(method); + } + } + + public void AddFunction(MethodInfo method) + { + var attributes = method.GetCustomAttributes(typeof(HelperFunctionDescriptionAttribute), false); + if (attributes.Length > 0) + { + var funcDescriptionAttrib = attributes[0] as HelperFunctionDescriptionAttribute; + var funcDescription = funcDescriptionAttrib!.Description; + + string json = GetMethodParametersJsonSchema(method); + var tool = new OnnxGenAIChatFunctionTool() { + Name = method.Name, + Description = funcDescription ?? $"The {method.Name} function", + Parameters = json + }; + _functions.TryAdd(method, tool); + } + } + + public IEnumerable GetChatTools() + { + return _functions.Values; + } + + public bool TryCallFunction(string functionName, string functionArguments, out string? result) + { + result = null; + if (!string.IsNullOrEmpty(functionName) && !string.IsNullOrEmpty(functionArguments)) + { + var function = _functions.FirstOrDefault(x => x.Value.Name == functionName); + if (function.Key != null) + { + result = CallFunction(function.Key, function.Value, functionArguments); + return true; + } + } + return false; + } + + // operator to add to FunctionFactories together + public static FunctionFactory operator +(FunctionFactory a, FunctionFactory b) + { + var newFactory = new FunctionFactory(); + a._functions.ToList().ForEach(x => newFactory._functions.Add(x.Key, x.Value)); + b._functions.ToList().ForEach(x => newFactory._functions.Add(x.Key, x.Value)); + return newFactory; + } + + private static string? CallFunction(MethodInfo methodInfo, OnnxGenAIChatFunctionTool chatTool, string argumentsAsJson) + { + var parsed = JsonDocument.Parse(argumentsAsJson).RootElement; + var arguments = new List(); + + var parameters = methodInfo.GetParameters(); + foreach (var parameter in parameters) + { + var parameterName = parameter.Name; + if (parameterName == null) continue; + + if (parsed.ValueKind == JsonValueKind.Object && parsed.TryGetProperty(parameterName, out var value)) + { + var parameterValue = value.ValueKind == JsonValueKind.String ? value.GetString() : value.GetRawText(); + if (parameterValue == null) continue; + + var argument = ParseParameterValue(parameterValue, parameter.ParameterType); + arguments.Add(argument); + } + } + + var args = arguments.ToArray(); + var result = CallFunction(methodInfo, args); + return ConvertFunctionResultToString(result); + } + + private static object? CallFunction(MethodInfo methodInfo, object[] args) + { + var t = methodInfo.ReturnType; + return t == typeof(Task) + ? CallVoidAsyncFunction(methodInfo, args) + : t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Task<>) + ? CallAsyncFunction(methodInfo, args) + : t.Name != "Void" + ? CallSyncFunction(methodInfo, args) + : CallVoidFunction(methodInfo, args); + } + + private static object? CallVoidAsyncFunction(MethodInfo methodInfo, object[] args) + { + var task = methodInfo.Invoke(null, args) as Task; + task!.Wait(); + return true; + } + + private static object? CallAsyncFunction(MethodInfo methodInfo, object[] args) + { + var task = methodInfo.Invoke(null, args) as Task; + task!.Wait(); + return task.GetType().GetProperty("Result")?.GetValue(task); + } + + private static object? CallSyncFunction(MethodInfo methodInfo, object[] args) + { + return methodInfo.Invoke(null, args); + } + + private static object? CallVoidFunction(MethodInfo methodInfo, object[] args) + { + methodInfo.Invoke(null, args); + return true; + } + + private static string? ConvertFunctionResultToString(object? result) + { + if (result is IEnumerable enumerable && !(result is string)) + { + using var stream = new MemoryStream(); + using var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }); + writer.WriteStartArray(); + foreach (var item in enumerable) + { + var str = item.ToString(); + writer.WriteStringValue(str); + } + writer.WriteEndArray(); + writer.Flush(); + return Encoding.UTF8.GetString(stream.ToArray()); + } + return result?.ToString(); + } + + private static object ParseParameterValue(string parameterValue, Type parameterType) + { + if (IsArrayType(parameterType)) + { + Type elementType = parameterType.GetElementType()!; + return CreateGenericCollectionFromJsonArray(parameterValue, typeof(Array), elementType); + } + + if (IsTuppleType(parameterType)) + { + Type elementType = parameterType.GetGenericArguments()[0]; + return CreateTuppleTypeFromJsonArray(parameterValue, elementType); + } + + if (IsGenericListOrEquivalentType(parameterType)) + { + Type elementType = parameterType.GetGenericArguments()[0]; + return CreateGenericCollectionFromJsonArray(parameterValue, typeof(List<>), elementType); + } + + switch (Type.GetTypeCode(parameterType)) + { + case TypeCode.Boolean: return bool.Parse(parameterValue!); + case TypeCode.Byte: return byte.Parse(parameterValue!); + case TypeCode.Decimal: return decimal.Parse(parameterValue!); + case TypeCode.Double: return double.Parse(parameterValue!); + case TypeCode.Single: return float.Parse(parameterValue!); + case TypeCode.Int16: return short.Parse(parameterValue!); + case TypeCode.Int32: return int.Parse(parameterValue!); + case TypeCode.Int64: return long.Parse(parameterValue!); + case TypeCode.SByte: return sbyte.Parse(parameterValue!); + case TypeCode.UInt16: return ushort.Parse(parameterValue!); + case TypeCode.UInt32: return uint.Parse(parameterValue!); + case TypeCode.UInt64: return ulong.Parse(parameterValue!); + case TypeCode.String: return parameterValue!; + default: return Convert.ChangeType(parameterValue!, parameterType); + } + } + + private static object CreateGenericCollectionFromJsonArray(string parameterValue, Type collectionType, Type elementType) + { + var root = JsonDocument.Parse(parameterValue).RootElement; + var array = root.ValueKind == JsonValueKind.Array + ? root.EnumerateArray().ToArray() + : Array.Empty(); + + if (collectionType == typeof(Array)) + { + var collection = Array.CreateInstance(elementType, array.Length); + for (int i = 0; i < array.Length; i++) + { + var parsed = ParseParameterValue(array[i].GetRawText(), elementType); + if (parsed != null) collection.SetValue(parsed, i); + } + return collection; + } + else if (collectionType == typeof(List<>)) + { + var collection = Activator.CreateInstance(collectionType.MakeGenericType(elementType)); + var list = collection as IList; + foreach (var item in array) + { + var parsed = ParseParameterValue(item.GetRawText(), elementType); + if (parsed != null) list!.Add(parsed); + } + return collection!; + } + + return array; + } + + private static object CreateTuppleTypeFromJsonArray(string parameterValue, Type elementType) + { + var list = new List(); + + var root = JsonDocument.Parse(parameterValue).RootElement; + var array = root.ValueKind == JsonValueKind.Array + ? root.EnumerateArray().ToArray() + : Array.Empty(); + + foreach (var item in array) + { + var parsed = ParseParameterValue(item.GetRawText(), elementType); + if (parsed != null) list!.Add(parsed); + } + + var collection = list.Count() switch + { + 1 => Activator.CreateInstance(typeof(Tuple<>).MakeGenericType(elementType), list[0]), + 2 => Activator.CreateInstance(typeof(Tuple<,>).MakeGenericType(elementType, elementType), list[0], list[1]), + 3 => Activator.CreateInstance(typeof(Tuple<,,>).MakeGenericType(elementType, elementType, elementType), list[0], list[1], list[2]), + 4 => Activator.CreateInstance(typeof(Tuple<,,,>).MakeGenericType(elementType, elementType, elementType, elementType), list[0], list[1], list[2], list[3]), + 5 => Activator.CreateInstance(typeof(Tuple<,,,,>).MakeGenericType(elementType, elementType, elementType, elementType, elementType), list[0], list[1], list[2], list[3], list[4]), + 6 => Activator.CreateInstance(typeof(Tuple<,,,,,>).MakeGenericType(elementType, elementType, elementType, elementType, elementType, elementType), list[0], list[1], list[2], list[3], list[4], list[5]), + 7 => Activator.CreateInstance(typeof(Tuple<,,,,,,>).MakeGenericType(elementType, elementType, elementType, elementType, elementType, elementType, elementType), list[0], list[1], list[2], list[3], list[4], list[5], list[6]), + _ => throw new Exception("Tuples with more than 7 elements are not supported") + }; + return collection!; + } + + private static string GetMethodParametersJsonSchema(MethodInfo method) + { + using var stream = new MemoryStream(); + using var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }); + writer.WriteStartObject(); + + var requiredParameters = new List(); + + writer.WriteString("type", "object"); + writer.WriteStartObject("properties"); + foreach (var parameter in method.GetParameters()) + { + if (parameter.Name == null) continue; + + if (!parameter.IsOptional) + { + requiredParameters.Add(parameter.Name); + } + + writer.WritePropertyName(parameter.Name); + WriteJsonSchemaForParameterWithDescription(writer, parameter); + } + writer.WriteEndObject(); + + writer.WriteStartArray("required"); + foreach (var requiredParameter in requiredParameters) + { + writer.WriteStringValue(requiredParameter); + } + writer.WriteEndArray(); + + writer.WriteEndObject(); + writer.Flush(); + + return Encoding.UTF8.GetString(stream.ToArray()); + } + + private static void WriteJsonSchemaForParameterWithDescription(Utf8JsonWriter writer, ParameterInfo parameter) + { + WriteJsonSchemaType(writer, parameter.ParameterType, GetParameterDescription(parameter)); + } + + private static string GetParameterDescription(ParameterInfo parameter) + { + var attributes = parameter.GetCustomAttributes(typeof(HelperFunctionParameterDescriptionAttribute), false); + var paramDescriptionAttrib = attributes.Length > 0 ? (attributes[0] as HelperFunctionParameterDescriptionAttribute) : null; + return paramDescriptionAttrib?.Description ?? $"The {parameter.Name} parameter"; + } + + private static void WriteJsonSchemaType(Utf8JsonWriter writer, Type t, string? parameterDescription = null) + { + if (IsJsonArrayEquivalentType(t)) + { + WriteJsonArraySchemaType(writer, t, parameterDescription); + } + else + { + WriteJsonPrimitiveSchemaType(writer, t, parameterDescription); + } + } + + private static void WriteJsonArraySchemaType(Utf8JsonWriter writer, Type containerType, string? parameterDescription = null) + { + writer.WriteStartObject(); + writer.WriteString("type", "array"); + + writer.WritePropertyName("items"); + WriteJsonArrayItemSchemaType(writer, containerType); + + if (!string.IsNullOrEmpty(parameterDescription)) + { + writer.WriteString("description", parameterDescription); + } + + writer.WriteEndObject(); + } + + private static void WriteJsonArrayItemSchemaType(Utf8JsonWriter writer, Type containerType) + { + WriteJsonSchemaType(writer, containerType.IsArray + ? containerType.GetElementType()! + : containerType.GetGenericArguments()[0]); + } + + private static void WriteJsonPrimitiveSchemaType(Utf8JsonWriter writer, Type primativeType, string? parameterDescription = null) + { + writer.WriteStartObject(); + writer.WriteString("type", GetJsonTypeFromPrimitiveType(primativeType)); + + if (!string.IsNullOrEmpty(parameterDescription)) + { + writer.WriteString("description", parameterDescription); + } + + writer.WriteEndObject(); + } + + private static string GetJsonTypeFromPrimitiveType(Type primativeType) + { + return Type.GetTypeCode(primativeType) switch + { + TypeCode.Boolean => "boolean", + TypeCode.Byte or TypeCode.SByte or TypeCode.Int16 or TypeCode.Int32 or TypeCode.Int64 or + TypeCode.UInt16 or TypeCode.UInt32 or TypeCode.UInt64 => "integer", + TypeCode.Decimal or TypeCode.Double or TypeCode.Single => "number", + TypeCode.String => "string", + _ => "string" + }; + } + + private static bool IsJsonArrayEquivalentType(Type t) + { + return IsArrayType(t) || IsTuppleType(t) || IsGenericListOrEquivalentType(t); + } + + private static bool IsArrayType(Type t) + { + return t.IsArray; + } + + private static bool IsTuppleType(Type parameterType) + { + return parameterType.IsGenericType && parameterType.GetGenericTypeDefinition().Name.StartsWith("Tuple"); + } + + private static bool IsGenericListOrEquivalentType(Type t) + { + return t.IsGenericType && + (t.GetGenericTypeDefinition() == typeof(List<>) || + t.GetGenericTypeDefinition() == typeof(ICollection<>) || + t.GetGenericTypeDefinition() == typeof(IEnumerable<>) || + t.GetGenericTypeDefinition() == typeof(IList<>) || + t.GetGenericTypeDefinition() == typeof(IReadOnlyCollection<>) || + t.GetGenericTypeDefinition() == typeof(IReadOnlyList<>)); + } + + private Dictionary _functions = new(); +} diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionDescriptionAttribute.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionDescriptionAttribute.cs new file mode 100644 index 00000000..82182e2b --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionDescriptionAttribute.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +public class HelperFunctionDescriptionAttribute : Attribute +{ + public HelperFunctionDescriptionAttribute() + { + } + + public HelperFunctionDescriptionAttribute(string description) + { + Description = description; + } + + public string? Description { get; set; } +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionParameterDescriptionAttribute.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionParameterDescriptionAttribute.cs new file mode 100644 index 00000000..00c90481 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/HelperFunctionParameterDescriptionAttribute.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +public class HelperFunctionParameterDescriptionAttribute : Attribute +{ + public HelperFunctionParameterDescriptionAttribute() + { + } + + public HelperFunctionParameterDescriptionAttribute(string? description = null) + { + Description = description; + } + + public string? Description { get; set; } +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsCustomFunctions.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsCustomFunctions.cs new file mode 100644 index 00000000..eaa5234e --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsCustomFunctions.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +using System; + +public class OnnxGenAIChatCompletionsCustomFunctions +{ + [HelperFunctionDescription("Gets the current weather for a location.")] + public static string GetCurrentWeather(string location) + { + return $"The weather in {location} is 72 degrees and sunny."; + } + + [HelperFunctionDescription("Gets the current date.")] + public static string GetCurrentDate() + { + var date = DateTime.Now; + return $"{date.Year}-{date.Month}-{date.Day}"; + } + + [HelperFunctionDescription("Gets the current time.")] + public static string GetCurrentTime() + { + var date = DateTime.Now; + return $"{date.Hour}:{date.Minute}:{date.Second}"; + } + + [HelperFunctionDescription("Sends a text message to a contact or a valid phone number.")] + public static bool SendTextMessage(string nameOrPhone, string message) + { + return true; + } +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsStreamingClass.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsStreamingClass.cs new file mode 100644 index 00000000..cb078bd6 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatCompletionsStreamingClass.cs @@ -0,0 +1,130 @@ +using System.Text; +using Microsoft.ML.OnnxRuntimeGenAI; +using System.Text.Json; + +public class OnnxGenAIChatStreamingClass +{ + public OnnxGenAIChatStreamingClass(string modelDirectory, string systemPrompt, FunctionFactory factory) + { + systemPrompt = UpdateSystemPrompt(systemPrompt, factory); + + _modelDirectory = modelDirectory; + _systemPrompt = systemPrompt; + _factory = factory; + + _messages = new List(); + _messages.Add(new OnnxGenAIChatContentMessage { Role = "system", Content = _systemPrompt }); + + _model = new Model(_modelDirectory); + _tokenizer = new Tokenizer(_model); + + _functionCallContext = new OnnxGenAIChatFunctionCallContext(_factory, _messages); + } + + public void ClearMessages() + { + _messages.Clear(); + _messages.Add(new OnnxGenAIChatContentMessage { Role = "system", Content = _systemPrompt }); + } + + public string GetChatCompletionStreaming(string userPrompt, Action? callback = null) + { + var debug = Environment.GetEnvironmentVariable("DEBUG") != null; + + _messages.Add(new OnnxGenAIChatContentMessage { Role = "user", Content = userPrompt }); + + var responseContent = string.Empty; + while (true) + { + var history = string.Join("\n", _messages + .Select(m => $"<|{m.Role}|>\n{m.Content}\n<|end|>")) + + "<|assistant|>\n"; + + // Console.WriteLine("\n**************** History ****************"); + // Console.WriteLine(history); + // Console.WriteLine("----------------------------------------\n"); + + using var tokens = _tokenizer.Encode(history); + + using var generatorParams = new GeneratorParams(_model); + generatorParams.SetSearchOption("max_length", 2048); + generatorParams.SetInputSequences(tokens); + + using var generator = new Generator(_model, generatorParams); + + var sb = new StringBuilder(); + while (!generator.IsDone()) + { + generator.ComputeLogits(); + generator.GenerateNextToken(); + + var outputTokens = generator.GetSequence(0); + var newToken = outputTokens.Slice(outputTokens.Length - 1, 1); + + var startAnswerAt = sb.ToString().LastIndexOf("<|answer|>"); + var endAnswerAt = sb.ToString().LastIndexOf("<|end_answer|>"); + var insideAnswer = startAnswerAt >= 0 && startAnswerAt > endAnswerAt; + + var output = _tokenizer.Decode(newToken); + sb.Append(output); + + if (insideAnswer || debug) callback?.Invoke(output); + + if (sb.ToString().Contains("<|end_answer|>")) break; + if (_functionCallContext.CheckForFunctions(sb)) break; + } + + if (_functionCallContext.TryCallFunctions(sb)) + { + _functionCallContext.Clear(); + continue; + } + + responseContent = sb.ToString(); + var ok = !string.IsNullOrWhiteSpace(responseContent); + if (ok) + { + _messages.Add(new OnnxGenAIChatContentMessage { Role = "assistant", Content = responseContent }); + } + + return responseContent; + } + } + + private static string UpdateSystemPrompt(string systemPrompt, FunctionFactory factory) + { + var functionsSchemaStartsAt = systemPrompt.IndexOf("<|functions_schema|>"); + if (functionsSchemaStartsAt >= 0) + { + var functionsSchemaEndsAt = systemPrompt.IndexOf("<|end_functions_schema|>", functionsSchemaStartsAt); + if (functionsSchemaEndsAt >= 0) + { + var asYaml = new StringBuilder(); + var tools = factory.GetChatTools().ToList(); + foreach (var tool in tools) + { + asYaml.Append($"- name: {tool.Name}\n"); + asYaml.Append($" description: {tool.Description}\n"); + asYaml.Append($" parameters: |\n"); + asYaml.Append($" {tool.Parameters}\n"); + } + + systemPrompt = systemPrompt.Remove(functionsSchemaStartsAt, functionsSchemaEndsAt - functionsSchemaStartsAt + "<|end_functions_schema|>".Length); + + var newFunctionsSchema = "<|functions_schema|>\n" + asYaml + "\n<|end_functions_schema|>"; + systemPrompt = systemPrompt.Insert(functionsSchemaStartsAt, newFunctionsSchema); + } + } + + return systemPrompt; + } + + private string _modelDirectory; + private string _systemPrompt; + private FunctionFactory _factory; + + private Model _model; + private Tokenizer _tokenizer; + private List _messages; + private OnnxGenAIChatFunctionCallContext _functionCallContext; +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatContentMessage.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatContentMessage.cs new file mode 100644 index 00000000..8f70a32d --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatContentMessage.cs @@ -0,0 +1,11 @@ +public class OnnxGenAIChatContentMessage +{ + public OnnxGenAIChatContentMessage() + { + Role = string.Empty; + Content = string.Empty; + } + + public string Role { get; set; } + public string Content { get; set; } +} diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatFunctionCallContext.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatFunctionCallContext.cs new file mode 100644 index 00000000..f2881ba9 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatFunctionCallContext.cs @@ -0,0 +1,179 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; + +public class OnnxGenAIChatFunctionCallContext +{ + public OnnxGenAIChatFunctionCallContext(FunctionFactory functionFactory, IList messages) + { + _functionFactory = functionFactory; + _messages = messages; + } + + public bool CheckForFunctions(StringBuilder content) + { + _functionCallsTagStartsAt = content.ToString().IndexOf("<|function_calls|>"); + if (_functionCallsTagStartsAt < 0) return false; + + _endFunctionCallsTagStartsAt = content.ToString().IndexOf("<|end_function_calls|>", _functionCallsTagStartsAt); + return _functionCallsTagStartsAt >= 0 && _endFunctionCallsTagStartsAt >= 0; + } + + public bool TryCallFunctions(StringBuilder content) + { + if (!CheckForFunctions(content)) return false; + if (!ExtractFunctions(content)) return false; + + var contentBeforeFunctionCallsTag = content.ToString().Substring(0, _functionCallsTagStartsAt); + if (!string.IsNullOrWhiteSpace(contentBeforeFunctionCallsTag)) + { + _messages.Add(new OnnxGenAIChatContentMessage { Role = "assistant", Content = contentBeforeFunctionCallsTag }); + } + + var contentIncludingEndFunctionCallsTag = content.ToString().Substring(0, _endFunctionCallsTagStartsAt + "<|end_function_calls|>".Length); + content.Remove(0, contentIncludingEndFunctionCallsTag.Length); + + var contentInsideFunctionCallsTag = contentIncludingEndFunctionCallsTag.Substring(_functionCallsTagStartsAt, _endFunctionCallsTagStartsAt - _functionCallsTagStartsAt + "<|end_function_calls|>".Length); + if (!string.IsNullOrWhiteSpace(contentInsideFunctionCallsTag)) + { + _messages.Add(new OnnxGenAIChatContentMessage { Role = "assistant", Content = contentInsideFunctionCallsTag }); + } + + var hasPlaceholders = _indexToArguments.Any(x => x.Value.Any(y => y.Value.Contains("PLACEHOLDER"))); + if (hasPlaceholders) + { + _messages.Add(new OnnxGenAIChatContentMessage { Role = "assistant", Content = "Oh, wait! I can't use placeholders in function calls. use answer and end_answer to ask the user for the missing information." }); + return true; + } + + var results = new StringBuilder(); + for (var index = 0; index < _indexToFunctionName.Count; index++) + { + var functionName = _indexToFunctionName[index]; + var functionArguments = _indexToArguments[index]; + var asJson = JsonFromArguments(functionArguments); + + var result = TryCatchCallFunction(functionName, asJson); + Console.WriteLine($"\rassistant-function: {functionName}({asJson}) => {result}"); + + results.AppendLine($"- api: {functionName}"); + results.AppendLine($" result: {result}"); + results.AppendLine(); + } + + Console.Write("\nAssistant: "); + + var functionCallResults = results.ToString().Trim(' ', '\n', '\r'); + functionCallResults = $"\n{functionCallResults}\n<|end_function_call_results|>\n"; + _messages.Add(new OnnxGenAIChatContentMessage { Role = "function_call_results", Content = functionCallResults }); + + return true; + } + + public void Clear() + { + _indexToFunctionName.Clear(); + _indexToArguments.Clear(); + _functionCallsTagStartsAt = -1; + _endFunctionCallsTagStartsAt = -1; + } + + private bool ExtractFunctions(StringBuilder content) + { + var lines = content.ToString() + .Substring(_functionCallsTagStartsAt, _endFunctionCallsTagStartsAt - _functionCallsTagStartsAt) + .Split('\n', StringSplitOptions.RemoveEmptyEntries); + + var currentFunctionIndex = -1; + for (var i = 0; i < lines.Length; i++) + { + var line = lines[i].Trim(' ', '\r', '\n', '-'); + if (string.IsNullOrEmpty(line)) continue; + + if (line.StartsWith("api:")) + { + var api = line.Substring("api:".Length).Trim(); + _indexToFunctionName[++currentFunctionIndex] = api; + _indexToArguments[currentFunctionIndex] = new List>(); + continue; + } + + if (currentFunctionIndex >= 0) + { + var colonIndex = line.IndexOf(':'); + if (colonIndex < 0) continue; + + var key = line.Substring(0, colonIndex).Trim(); + var value = line.Substring(colonIndex + 1).Trim(); + _indexToArguments[currentFunctionIndex].Add(new KeyValuePair(key, value)); + } + } + + return _indexToFunctionName.Count > 0; + } + + private string JsonFromArguments(List> functionArguments) + { + var tryExpandParametersJson = functionArguments.Any(x => x.Key == "parameters" && !string.IsNullOrEmpty(x.Value)); + if (tryExpandParametersJson) + { + var parametersSpecifiedAsJson = functionArguments.First(x => x.Key == "parameters"); + functionArguments.Remove(parametersSpecifiedAsJson); + var parameters = JsonSerializer.Deserialize>(parametersSpecifiedAsJson.Value); + if (parameters != null) + { + foreach (var parameter in parameters) + { + functionArguments.Add(new KeyValuePair(parameter.Key, parameter.Value)); + } + } + } + + using var stream = new MemoryStream(); + using var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }); + writer.WriteStartObject(); + + foreach (var argument in functionArguments) + { + writer.WritePropertyName(argument.Key); + writer.WriteStringValue(argument.Value); + } + + writer.WriteEndObject(); + writer.Flush(); + + return Encoding.UTF8.GetString(stream.ToArray()); + } + + private string? TryCatchCallFunction(string functionName, string asJson) + { + string? result; + try + { + var ok = _functionFactory.TryCallFunction(functionName, asJson, out result); + if (!ok) result = $"Function '{functionName}' not found."; + } + catch (Exception ex) + { + result = $"Error calling function '{functionName}': {ex.Message}"; + } + + return result; + } + + + private FunctionFactory _functionFactory; + private IList _messages; + + private int _functionCallsTagStartsAt = -1; + private int _endFunctionCallsTagStartsAt = -1; + + private Dictionary _indexToFunctionName = []; + private Dictionary>> _indexToArguments = []; +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatTool.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatTool.cs new file mode 100644 index 00000000..951d3ebd --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/OnnxGenAIChatTool.cs @@ -0,0 +1,11 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +public class OnnxGenAIChatFunctionTool +{ + public required string Name { get; init; } + public required string Description { get; init; } + public required string Parameters { get; init; } +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Phi3ChatStreaming.csproj._ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Phi3ChatStreaming.csproj._ new file mode 100644 index 00000000..b5f1d2c9 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Phi3ChatStreaming.csproj._ @@ -0,0 +1,17 @@ + + + Exe + net8.0 + enable + enable + + + {{if contains(toupper("{ONNX_GENAI_MODEL_PLATFORM}"), "DIRECTML")}} + + {{else if contains(toupper("{ONNX_GENAI_MODEL_PLATFORM}"), "CUDA")}} + + {{else}} + + {{endif}} + + diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Program.cs b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Program.cs new file mode 100644 index 00000000..416815ec --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/Program.cs @@ -0,0 +1,44 @@ +public class Program +{ + public static void Main(string[] args) + { + var modelDirectory = Environment.GetEnvironmentVariable("ONNX_GENAI_MODEL_PATH") ?? ""; + var systemPrompt = Environment.GetEnvironmentVariable("ONNX_GENAI_SYSTEM_PROMPT") ?? "@system.txt"; + + if (string.IsNullOrEmpty(modelDirectory) || modelDirectory.StartsWith(" { + Console.Write(update); + }); + Console.WriteLine("\n"); + } + } +} diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/_.json b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/_.json new file mode 100644 index 00000000..e2195763 --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/_.json @@ -0,0 +1,7 @@ +{ + "_LongName": "Phi-3 Chat Completions (w/ ONNX + Functions)", + "_ShortName": "phi3-onnx-chat-streaming-with-functions", + "_Language": "C#", + + "ONNX_GENAI_MODEL_PLATFORM": "DIRECTML" +} \ No newline at end of file diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/get-phi3-mini-onnx.cmd b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/get-phi3-mini-onnx.cmd new file mode 100644 index 00000000..ddc57b1d --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/get-phi3-mini-onnx.cmd @@ -0,0 +1,3 @@ +git lfs install +git clone https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx +git lfs checkout diff --git a/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/system.txt b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/system.txt new file mode 100644 index 00000000..4c50d8af --- /dev/null +++ b/src/ai/.x/templates/phi3-onnx-chat-streaming-with-functions-cs/system.txt @@ -0,0 +1,134 @@ +You are a diligently helpful AI assistant. + +Through the use of APIs, you provide answers to queries, always exemplifying rigorous logic and intelligence. API names must be called out only when required to resolve a query. Each API call expects complete fulfillment of required parameters - invest in discussions with the user or utilize provided functions to get the necessary information. + +Upon receiving a query, condense all required function calls into one location, using the beginning and ending tags: <|function_calls|> and <|end_function_calls|>. Multiple function calls should have a newline separation. + +Concisely note your thoughts and reasoning relating to the use or non-use of functions within the <|thoughts|> and <|end_thoughts|> tags, positioned before the function calls like: + +<|thoughts|> If the user seeks X, I'll need to: + +1. Call function_name for the following reason, +2. Invoke function_name_2 benefiting from this, +3. Continue so on, <|end_thoughts|> + +Responses should limit to the content within the <|function_calls|> & <|end_function_calls|>, <|answer|> & <|end_answer|> and <|thoughts|> & <|end_thoughts|> constructs. + +Use the template examples below to guide your responses. Note PLACEHOLDER1, PLACEHOLDER2, and PLACEHOLDER3 are used as temporary values and should be replaced with actual data. + + <|function_call_examples|> + - <|user|> + hey + <|end> + <|assistant|> + <|thoughts|> + No function calls needed. I'll respond directly. + <|end_thoughts|> + <|answer|> + Hello! How can I help you today? + <|end_answer|> + <|end|> + + - <|user|> + who was president in 1999 + <|end|> + <|assistant|> + <|thoughts|> + I know the answer to this. No need for function calls. + <|end_thoughts|> + <|answer|> + Bill Clinton was the president in 1999. + <|end_answer|> + <|end|> + + - <|user|> + what's the date and the time? + <|end|> + <|assistant|> + <|thoughts|> + I can call both functions in one block. + - Call GetCurrentDate to get the current date. + - Call GetCurrentTime to get the current time. + <|end_thoughts|> + <|function_calls|> + - api: GetCurrentDate + parameters: {} + - api: GetCurrentTime + parameters: {} + <|end_function_calls|> + <|end|> + <|function_call_results|> + - api: GetCurrentDate + result: PLACEHOLDER1 + - api: GetCurrentTime + result: PLACEHOLDER2 + <|end_function_call_results|> + <|answer|>The current date is PLACEHOLDER1 and the current time is PLACEHOLDER2. + <|end_answer|> + <|end|> + + - <|user|>send the weather to PLACEHOLDER1 + <|end|> + <|assistant|> + <|thoughts|>I MUST ask the user for the required location. + <|end_thoughts|> + <|answer|>What location? + <|end_answer|> + <|end|> + <|user|>PLACEHOLDER2 + <|end|> + <|thoughts|>I have the location. Now, I'll: + 1. Call GetCurrentWeather for the location. + 2. Then, call SendTextMessage to send the weather from the previous function. + <|end_thoughts|> + <|function_calls|> + - api: GetCurrentWeather + parameters: { "location": "PLACEHOLDER2" } + <|end_function_calls|> + <|function_call_results|> + - api: GetCurrentWeather + result: It's currently PLACEHOLDER3. + <|end_function_call_results|> + <|function_calls|> + - api: SendTextMessage + parameters: { "nameOrPhone": "PLACEHOLDER1", "message": "The weather in PLACEHOLDER2 is PLACEHOLDER3." } + <|end_function_calls|> + <|function_call_results|> + - api: SendTextMessage + result: true + <|end_function_call_results|> + <|answer|>I sent the weather in PLACEHOLDER2 to PLACEHOLDER1. + <|end_answer|> + <|end|> + + <|end_function_call_examples|> + +The aforementioned functions were examples and are not accessible. + +Functions available for actual use are as follows: + +<|functions_schema|> +- name: GetCurrentWeather + description: Get the current weather in a given location + parameters: | + {"type":"object","properties":{"location":{"type":"string","description":"The location parameter"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}},"required":["location"]} + +- name: GetCurrentDate + description: Get the current date + +- name: GetCurrentTime + description: Get the current time +<|end_functions_schema|> + +Keep the following in mind: +* Respond without calling functions when possible. +* Calls made to function should be invisible to the user. +* Avoid unnecessary function calls. +* Rely strictly on the provided APIs. + +Critically, you **MUST** ask the user for information when not provided. +You **MUST NOT** use PLACEHOLDERS as parameter inputs or in responses. + +You must always start your response with <|thoughts|>, express your thoughts and reasoning, and end it with <|end_thoughts|>. + + diff --git a/tests/test.yaml b/tests/test.yaml index 4dbb5bce..c69fb172 100644 --- a/tests/test.yaml +++ b/tests/test.yaml @@ -104,6 +104,7 @@ ^OpenAI +Chat +Webpage +\(w/ +Speech +input/output\) +openai-chat-webpage-with-speech +TypeScript *\r?$\n ^OpenAI +Chat +Webpage +\(w/ +Functions +\+ +Speech\) +openai-chat-webpage-with-speech-and-functions +TypeScript *\r?$\n ^Phi-3 Chat Completions \(w/ ONNX\) +phi3-onnx-chat-streaming +C# +\r?$\n + ^Phi-3 Chat Completions \(w/ ONNX \+ Functions\) +phi3-onnx-chat-streaming-with-functions +C# +\r?$\n ^Semantic +Kernel +Chat +Completions +\(Streaming\) +sk-chat-streaming +C# *\r?$\n ^Semantic +Kernel +Chat +Completions +\(w/ +Data +\+ +AI +Search\) +sk-chat-streaming-with-data +C# *\r?$\n ^Semantic +Kernel +Chat +Completions +\(w/ +Functions\) +sk-chat-streaming-with-functions +C# *\r?$\n @@ -114,4 +115,4 @@ ^Speech-to-text +\(w/ +Keyword +detection\) +speech-to-text-with-keyword +C#, +Python *\r?$\n ^Speech-to-text +\(w/ +Translation\) +speech-to-text-with-translation +C#, +Python *\r?$\n ^Text-to-speech +text-to-speech +C#, +Python *\r?$\n - ^Text-to-speech +\(w/ +File +output\) +text-to-speech-with-file +C#, +Python *\r?$\n \ No newline at end of file + ^Text-to-speech +\(w/ +File +output\) +text-to-speech-with-file +C#, +Python *\r?$\n