diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index c487d41c..2579a503 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -35,6 +35,7 @@ public class ExampleRunner { "Batched Executor: LLava", BatchedExecutorLLava.Run }, { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run }, { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run }, + { "Custom Sampling Pipeline", CustomSampler.Run }, { "Speech Chat: Integration with Whisper.net", SpeechChat.Run }, { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } } }; diff --git a/LLama.Examples/Examples/CustomSampler.cs b/LLama.Examples/Examples/CustomSampler.cs new file mode 100644 index 00000000..d2df8db2 --- /dev/null +++ b/LLama.Examples/Examples/CustomSampler.cs @@ -0,0 +1,136 @@ +using LLama.Common; +using LLama.Examples.Extensions; +using LLama.Native; +using LLama.Sampling; + +namespace LLama.Examples.Examples +{ + /// <summary> + /// Example which generates text through a sampling pipeline containing a user-defined sampler stage + /// </summary> + public class CustomSampler + { + /// <summary> + /// Load the model at the path returned by UserSettings.GetModelPath, then answer questions typed at the console until input ends + /// </summary> + public static async Task Run() + { + var modelPath = UserSettings.GetModelPath(); + + var parameters = new ModelParams(modelPath); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); + + var ex = new StatelessExecutor(model, parameters); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("In this example a custom sampling pipeline with a custom sampler stage is being used. This demonstrates how to customise the samplers used, and " + + "how to create a completely custom sampler stage which modifies the logits or selects a token.\n" + + "\n" + + "In this case the custom sampler stage removes the most likely token. This will probably produce bad results, it's just a demo!" 
+ ); + Console.ForegroundColor = ConsoleColor.White; + + var inferenceParams = new InferenceParams + { + SamplingPipeline = new CustomSamplingPipeline(), + MaxTokens = 50 + }; + + while (true) + { + Console.Write("\nQuestion: "); + Console.ForegroundColor = ConsoleColor.Green; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; + // Console.ReadLine returns null once standard input is closed; leave the loop instead of generating answers to empty questions forever + if (prompt is null) + break; + Console.Write("Answer: "); + prompt = $"Question: {prompt.Trim()} Answer: "; + await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner()) + { + Console.Write(text); + } + } + } + } + + /// <summary> + /// Sampling pipeline built from top-k filtering, the custom <see cref="RemoveMostLikelyToken"/> stage, softmax and a distribution sampler + /// </summary> + public class CustomSamplingPipeline + : BaseSamplingPipeline + { + /// <inheritdoc /> + protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context) + { + var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); + + // Take only the 10 most likely tokens + chain.AddTopK(10); + + // Remove the most likely token + chain.AddCustom(new RemoveMostLikelyToken()); + + // Select from the distribution + chain.AddSoftmax(); + chain.AddDistributionSampler(42); + + return chain; + } + } + + /// <summary> + /// Custom sampler stage which makes the single most likely candidate token impossible to pick + /// </summary> + public class RemoveMostLikelyToken + : ICustomSampler + { + /// <inheritdoc /> + public string Name => "Remove Most Likely Token"; + + /// <inheritdoc /> + public void Apply(ref LLamaTokenDataArrayNative tokenData) + { + // Doesn't make sense to run this stage if there is only one candidate left + if (tokenData.Size <= 1) + return; + + // Ensure token data is sorted, so most likely token is first. + // Note that this is a descending sort, the **largest** value is first. + if (!tokenData.Sorted) + tokenData.Data.Sort((a, b) => b.Logit.CompareTo(a.Logit)); + + // Make the most likely token impossible to pick + tokenData.Data[0].Logit = float.NegativeInfinity; + + // It's **critically** important to set this if the logits are no longer sorted after the custom + // sampler has run. If you're not sure, it's always safer to set it to false. 
+ // + // In this case, because the first logit has just been set to negative infinity + // the token data is definitely not sorted! + tokenData.Sorted = false; + } + + public void Accept(LLamaToken token) + { + // This stage keeps no state, so there is nothing to record + } + + public void Reset() + { + // This stage keeps no state, so there is nothing to reset + } + + /// <inheritdoc /> + public ICustomSampler Clone() + { + return new RemoveMostLikelyToken(); + } + + public void Dispose() + { + // This stage holds no resources, so there is nothing to release + } + } +} diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index d101a833..d6ab139e 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -149,7 +149,7 @@ public struct LLamaTokenDataArrayNative /// /// Number of LLamaTokenData in the array /// - public ulong size; + private ulong _size; /// /// The index in the array (i.e. not the token id) @@ -167,13 +167,13 @@ public Span Data { unsafe { - return new Span(_data, checked((int)size)); + return new Span(_data, checked((int)Size)); } } } /// - /// Indicates if the items in the array are sorted + /// Indicates if the items in the array are sorted, so the most likely token is first /// public bool Sorted { @@ -190,6 +190,20 @@ public long Selected set => _selected = value; } + /// + /// Number of LLamaTokenData in the array. 
Set this to shrink the array; the setter throws ArgumentOutOfRangeException if the new value is larger than the current size + /// + public ulong Size + { + get => _size; + set + { + if (value > _size) + throw new ArgumentOutOfRangeException(nameof(value), "Cannot set Size property to a larger value"); + _size = value; + } + } + /// /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray /// @@ -205,7 +219,9 @@ public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataA native = new LLamaTokenDataArrayNative { _data = (LLamaTokenData*)handle.Pointer, - size = (ulong)array.Data.Length, + // Write the backing field directly: the Size setter only permits shrinking, so on a freshly + // created struct (where _size is still 0) it would throw for every non-empty array. + _size = (ulong)array.Data.Length, Sorted = array.Sorted }; } diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs index 9cde52e4..902e34f2 100644 --- a/LLama/Native/SafeLLamaSamplerHandle.cs +++ b/LLama/Native/SafeLLamaSamplerHandle.cs @@ -162,7 +162,7 @@ public static SafeLLamaSamplerChainHandle Create(LLamaSamplerChainParams @params /// The index of the stage to clone public void AddClone(SafeLLamaSamplerChainHandle src, int index) { - if (index < 0 || index >= Count) + if (index < 0 || index >= src.Count) throw new ArgumentOutOfRangeException(nameof(index)); llama_sampler_chain_add( @@ -193,6 +193,22 @@ public void Remove(int index) } + /// <summary> + /// Add a custom sampler stage + /// </summary> + /// <typeparam name="TSampler">Type of the custom sampler implementation</typeparam> + /// <param name="sampler">The custom sampler to wrap in a native sampler struct and append to this chain</param> + public void AddCustom<TSampler>(TSampler sampler) + where TSampler : class, ICustomSampler + { + unsafe + { + var samplerHandle = CustomSamplerHandle.Create(sampler); + llama_sampler_chain_add(this, (IntPtr)samplerHandle.GetLLamaSamplerPointer()); + } + } + + /// /// Add a sampler which picks the most likely token. 
/// @@ -502,70 +518,213 @@ public record struct LLamaLogitBias public float Bias; } -/* todo: Custom sampler stuff /// /// /// /// llama_sampler_i -public struct LLamaSamplerINative +[StructLayout(LayoutKind.Sequential)] +internal struct LLamaSamplerINative { - // Delegate definitions for the function pointers + /// <summary> + /// Get the name of this sampler + /// </summary> + /// <param name="smpl">The native sampler struct this callback is invoked for</param> + /// <returns>The sampler name. NOTE(review): returned as a marshalled string - confirm who owns and frees the native copy</returns> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate string NameDelegate(IntPtr smpl); + public delegate string NameDelegate(ref LLamaSamplerNative smpl); + /// <summary> + /// Update internal sampler state after a token has been chosen + /// </summary> + /// <param name="smpl">The native sampler struct this callback is invoked for</param> + /// <param name="token">The token that was chosen</param> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate void AcceptDelegate(IntPtr smpl, LLamaToken token); + public delegate void AcceptDelegate(ref LLamaSamplerNative smpl, LLamaToken token); + /// <summary> + /// Apply this sampler to a set of logits + /// </summary> + /// <param name="smpl">The native sampler struct this callback is invoked for</param> + /// <param name="cur_p">The candidate token data, passed by reference so the sampler can modify it</param> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate void ApplyDelegate(IntPtr smpl, ref LLamaTokenDataArrayNative cur_p); + public delegate void ApplyDelegate(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative cur_p); + /// <summary> + /// Reset the internal state of this sampler + /// </summary> + /// <param name="smpl">The native sampler struct this callback is invoked for</param> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate void ResetDelegate(IntPtr smpl); + public delegate void ResetDelegate(ref LLamaSamplerNative smpl); + /// <summary> + /// Create a clone of this sampler + /// </summary> + /// <param name="smpl">The native sampler struct to clone</param> + /// <returns>Pointer to the native sampler struct of the clone</returns> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate IntPtr CloneDelegate(IntPtr smpl); + public unsafe delegate LLamaSamplerNative* CloneDelegate(ref LLamaSamplerNative smpl); + /// <summary> + /// Free all resources held by this sampler + /// </summary> + /// <param name="smpl">The native sampler struct being freed</param> [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - public delegate void FreeDelegate(IntPtr smpl); - - // Struct fields corresponding to function pointers - public NameDelegate name; - public AcceptDelegate accept; - public ApplyDelegate apply; - public ResetDelegate reset; - public 
CloneDelegate clone; - public FreeDelegate free; + + public unsafe delegate* Name; // NOTE(review): the function pointer signatures on these six fields (and on the casts in CustomSamplerHandle.Create) are missing in this copy of the patch - restore them from the original commit + public unsafe delegate* Accept; + public unsafe delegate* Apply; + public unsafe delegate* Reset; + public unsafe delegate* Clone; + public unsafe delegate* Free; } /// /// /// /// llama_sampler +[StructLayout(LayoutKind.Sequential)] internal unsafe struct LLamaSamplerNative { - public LLamaSamplerINative* iface; - public IntPtr ctx; + /// <summary> + /// Holds the function pointers which make up the actual sampler + /// </summary> + public LLamaSamplerINative* Interface; + + /// <summary> + /// Any additional context this sampler needs, may be anything. We will use it + /// to hold a GCHandle. + /// </summary> + public IntPtr Context; } -internal class CustomSamplerWrapper +internal class CustomSamplerHandle { - public GCHandle Handle; - public ICustomSampler Sampler; + /// <summary> + /// This GCHandle roots this object, preventing it from being freed. + /// </summary> + private GCHandle _gcHandle; + + /// <summary> + /// A reference to the user code which implements the custom sampler + /// </summary> + private readonly ICustomSampler _sampler; + + private unsafe LLamaSamplerNative* _samplerNativePtr; // Unmanaged LLamaSamplerNative struct, allocated in Create and released in Free + private unsafe LLamaSamplerINative* _samplerNativeInterfacePtr; // Unmanaged table of callback pointers, allocated in Create and released in Free + + private CustomSamplerHandle(ICustomSampler sampler) + { + _sampler = sampler; + } + + public static CustomSamplerHandle Create(ICustomSampler sampler) + { + var handle = new CustomSamplerHandle(sampler); + handle._gcHandle = GCHandle.Alloc(handle); // Root the wrapper so the static callbacks can find it again through Context; released in Free + + unsafe + { + handle._samplerNativeInterfacePtr = (LLamaSamplerINative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerINative)); + handle._samplerNativeInterfacePtr->Name = (delegate*)Marshal.GetFunctionPointerForDelegate(Name); // NOTE(review): the delegates created for these six callbacks are not stored anywhere - confirm they cannot be collected while native code still holds the function pointers + handle._samplerNativeInterfacePtr->Accept = (delegate*)Marshal.GetFunctionPointerForDelegate(Accept); + handle._samplerNativeInterfacePtr->Apply = (delegate*)Marshal.GetFunctionPointerForDelegate(Apply); + handle._samplerNativeInterfacePtr->Reset = 
(delegate*)Marshal.GetFunctionPointerForDelegate(Reset); + handle._samplerNativeInterfacePtr->Clone = (delegate*)Marshal.GetFunctionPointerForDelegate(Clone); + handle._samplerNativeInterfacePtr->Free = (delegate*)Marshal.GetFunctionPointerForDelegate(Free); + + handle._samplerNativePtr = (LLamaSamplerNative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerNative)); + handle._samplerNativePtr->Context = (IntPtr)handle._gcHandle; // Lets GetSampler recover this wrapper from the native struct + handle._samplerNativePtr->Interface = handle._samplerNativeInterfacePtr; + } + + return handle; + } + + /// <summary> + /// Get a pointer to a `llama_sampler` (LLamaSamplerNative) struct, suitable for passing to `llama_sampler_chain_add` + /// </summary> + /// <returns>Pointer to the unmanaged struct allocated in Create</returns> + /// <remarks>The stored pointer is set to null once the Free callback has run</remarks> + public unsafe LLamaSamplerNative* GetLLamaSamplerPointer() + { + return _samplerNativePtr; + } + + private static CustomSamplerHandle GetSampler(ref LLamaSamplerNative smpl) + { + return (CustomSamplerHandle)GCHandle.FromIntPtr(smpl.Context).Target!; // Context holds the GCHandle stored by Create + } + + private static string Name(ref LLamaSamplerNative smpl) + { + return GetSampler(ref smpl)._sampler.Name; + } + + private static void Accept(ref LLamaSamplerNative smpl, LLamaToken token) + { + GetSampler(ref smpl)._sampler.Accept(token); + } + + private static void Apply(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative candidates) + { + GetSampler(ref smpl)._sampler.Apply(ref candidates); + } + + private static void Reset(ref LLamaSamplerNative smpl) + { + GetSampler(ref smpl)._sampler.Reset(); + } + + private static unsafe LLamaSamplerNative* Clone(ref LLamaSamplerNative smpl) + { + var sampler = GetSampler(ref smpl); + + return Create(sampler._sampler.Clone()).GetLLamaSamplerPointer(); // The clone gets its own wrapper, GCHandle and unmanaged structs via Create + } + + private static unsafe void Free(ref LLamaSamplerNative smpl) + { + var sampler = GetSampler(ref smpl); + + if (sampler._samplerNativePtr != null) + { + Marshal.FreeHGlobal((IntPtr)sampler._samplerNativePtr); // NOTE(review): confirm the native caller does not also free this llama_sampler struct after invoking the free callback, otherwise this is a double free + sampler._samplerNativePtr = null; + } + + if (sampler._samplerNativeInterfacePtr != null) + { + 
Marshal.FreeHGlobal((IntPtr)sampler._samplerNativeInterfacePtr); + sampler._samplerNativeInterfacePtr = null; + } + + sampler._gcHandle.Free(); // Release the root taken in Create + + sampler._sampler.Dispose(); // Finally let the user's sampler release whatever it holds + } } /// /// A custom sampler stage for modifying logits or selecting a token /// public interface ICustomSampler + : IDisposable // Dispose is invoked by CustomSamplerHandle's Free callback { /// - /// The name of this stage + /// The human readable name of this stage /// string Name { get; } /// - /// Apply this stage to a set of logits + /// Apply this stage to a set of logits. + /// This can modify logits or select a token (or both). + /// If logits are modified the Sorted flag must be set to false. /// + /// <remarks> + /// If the logits are no longer sorted after the custom sampler has run it is critically important to + /// set Sorted=false. If unsure, always set it to false, this is a safe default. + /// </remarks> /// void Apply(ref LLamaTokenDataArrayNative tokenData); @@ -584,10 +743,4 @@ public interface ICustomSampler /// Create a clone of this sampler /// ICustomSampler Clone(); - - /// - /// Free all unmanaged resources held by this sampler - /// - void Free(); -} -*/ \ No newline at end of file +} \ No newline at end of file