diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index c487d41c..2579a503 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -35,6 +35,7 @@ public class ExampleRunner
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
+ { "Custom Sampling Pipeline", CustomSampler.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
diff --git a/LLama.Examples/Examples/CustomSampler.cs b/LLama.Examples/Examples/CustomSampler.cs
new file mode 100644
index 00000000..d2df8db2
--- /dev/null
+++ b/LLama.Examples/Examples/CustomSampler.cs
@@ -0,0 +1,114 @@
+using LLama.Common;
+using LLama.Examples.Extensions;
+using LLama.Native;
+using LLama.Sampling;
+
+namespace LLama.Examples.Examples
+{
+    /// <summary>
+    /// Interactive console example: answers questions with a StatelessExecutor whose inference
+    /// parameters use <see cref="CustomSamplingPipeline"/>.
+    /// </summary>
+    public class CustomSampler
+    {
+        /// <summary>
+        /// Load the model selected through UserSettings, print a description of the demo and then
+        /// answer questions typed into the console, generating at most 50 tokens per answer.
+        /// </summary>
+        /// <returns>A task for the example; the question loop has no exit condition, so it does not complete normally</returns>
+        public static async Task Run()
+        {
+            var modelPath = UserSettings.GetModelPath();
+
+            var parameters = new ModelParams(modelPath);
+            using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+
+            var ex = new StatelessExecutor(model, parameters);
+
+            // The two paragraphs need an explicit separator: adjacent literals are joined as-is, so an
+            // empty string between them would print "...selects a token.In this case...".
+            Console.ForegroundColor = ConsoleColor.Yellow;
+            Console.WriteLine("In this example a custom sampling pipeline with a custom sampler stage is being used. This demonstrates how to customise the samplers used, and " +
+                              "how to create a completely custom sampler stage which modifies the logits or selects a token." +
+                              "\n\n" +
+                              "In this case the custom sampler stage removes the most likely token. This will probably produce bad results, it's just a demo!"
+            );
+            Console.ForegroundColor = ConsoleColor.White;
+
+            var inferenceParams = new InferenceParams
+            {
+                SamplingPipeline = new CustomSamplingPipeline(),
+                MaxTokens = 50
+            };
+
+            while (true)
+            {
+                // Show the user's typing in green, then switch back before printing the answer
+                Console.Write("\nQuestion: ");
+                Console.ForegroundColor = ConsoleColor.Green;
+                var prompt = Console.ReadLine();
+                Console.ForegroundColor = ConsoleColor.White;
+                Console.Write("Answer: ");
+
+                // ReadLine may return null, in which case the question part of the prompt is empty
+                prompt = $"Question: {prompt?.Trim()} Answer: ";
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner())
+                {
+                    Console.Write(text);
+                }
+            }
+        }
+    }
+
+    /// <summary>
+    /// Sampling pipeline used by the example: a top-k filter, the custom
+    /// <see cref="RemoveMostLikelyToken"/> stage, softmax, then a seeded distribution sampler.
+    /// </summary>
+    public class CustomSamplingPipeline
+        : BaseSamplingPipeline
+    {
+        /// <summary>
+        /// Build the sampler chain for this pipeline. The stages run in the order they are added.
+        /// </summary>
+        /// <param name="context">Not used by this pipeline</param>
+        /// <returns>The newly created chain</returns>
+        protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
+        {
+            var stages = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
+
+            // Stage 1: keep only the 10 highest ranked candidates
+            stages.AddTopK(10);
+
+            // Stage 2: the custom stage, which knocks out the current favourite
+            stages.AddCustom(new RemoveMostLikelyToken());
+
+            // Stage 3: normalise what is left and draw a token from it (seed is fixed at 42)
+            stages.AddSoftmax();
+            stages.AddDistributionSampler(42);
+
+            return stages;
+        }
+    }
+
+    /// <summary>
+    /// Demo sampler stage which makes the highest-logit candidate impossible to select.
+    /// It holds no state, so Accept, Reset and Dispose have nothing to do.
+    /// </summary>
+    public class RemoveMostLikelyToken
+        : ICustomSampler
+    {
+        /// <summary>Human readable name of this stage</summary>
+        public string Name => "Remove Most Likely Token";
+
+        /// <summary>
+        /// Order the candidates by descending logit (unless they are already flagged as sorted) and
+        /// set the first candidate's logit to negative infinity.
+        /// </summary>
+        /// <param name="tokenData">Candidate tokens, modified in place</param>
+        public void Apply(ref LLamaTokenDataArrayNative tokenData)
+        {
+            // With zero or one candidate there is nothing sensible to remove
+            if (tokenData.Size > 1)
+            {
+                var candidates = tokenData.Data;
+
+                // The favourite has to sit at index 0, so order by logit with the **largest** value first
+                if (!tokenData.Sorted)
+                    candidates.Sort((x, y) => y.Logit.CompareTo(x.Logit));
+
+                // Rule out the favourite
+                candidates[0].Logit = float.NegativeInfinity;
+
+                // Index 0 now carries the smallest possible logit, so the data is definitely no longer
+                // sorted. Clearing this flag is **critical** whenever a stage may have broken the
+                // ordering; when in doubt, false is always the safe value.
+                tokenData.Sorted = false;
+            }
+        }
+
+        /// <summary>Nothing to record: this stage keeps no state</summary>
+        /// <param name="token">The accepted token, ignored</param>
+        public void Accept(LLamaToken token) { }
+
+        /// <summary>Nothing to reset: this stage keeps no state</summary>
+        public void Reset() { }
+
+        /// <summary>Create a fresh, independent instance of this stage</summary>
+        /// <returns>A new <see cref="RemoveMostLikelyToken"/></returns>
+        public ICustomSampler Clone() => new RemoveMostLikelyToken();
+
+        /// <summary>Nothing to release: this stage owns no resources</summary>
+        public void Dispose() { }
+    }
+}
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index d101a833..d6ab139e 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -149,7 +149,7 @@ public struct LLamaTokenDataArrayNative
///
/// Number of LLamaTokenData in the array
///
- public ulong size;
+ private ulong _size;
///
/// The index in the array (i.e. not the token id)
@@ -167,13 +167,13 @@ public Span Data
{
unsafe
{
- return new Span(_data, checked((int)size));
+ return new Span(_data, checked((int)Size));
}
}
}
///
- /// Indicates if the items in the array are sorted
+ /// Indicates if the items in the array are sorted, so the most likely token is first
///
public bool Sorted
{
@@ -190,6 +190,20 @@ public long Selected
set => _selected = value;
}
+    /// <summary>
+    /// Number of LLamaTokenData in the array. Set this to shrink the array
+    /// </summary>
+    /// <remarks>
+    /// Data is created with this value as its length, so lowering it hides the trailing entries.
+    /// The setter accepts the current value or anything smaller.
+    /// </remarks>
+    /// <exception cref="ArgumentOutOfRangeException">The assigned value is larger than the current size</exception>
+    public ulong Size
+    {
+        get => _size;
+        set => _size = value <= _size
+            ? value
+            : throw new ArgumentOutOfRangeException(nameof(value), "Cannot set Size property to a larger value");
+    }
+
///
/// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
///
@@ -205,7 +219,7 @@ public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataA
native = new LLamaTokenDataArrayNative
{
_data = (LLamaTokenData*)handle.Pointer,
- size = (ulong)array.Data.Length,
+ // Assign the backing field, not the Size property: the Size setter rejects any value larger
+ // than the current _size, which is still zero on this freshly created struct, so going
+ // through the property would throw for every non-empty array.
+ _size = (ulong)array.Data.Length,
Sorted = array.Sorted
};
}
diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs
index 9cde52e4..902e34f2 100644
--- a/LLama/Native/SafeLLamaSamplerHandle.cs
+++ b/LLama/Native/SafeLLamaSamplerHandle.cs
@@ -162,7 +162,7 @@ public static SafeLLamaSamplerChainHandle Create(LLamaSamplerChainParams @params
/// The index of the stage to clone
public void AddClone(SafeLLamaSamplerChainHandle src, int index)
{
- if (index < 0 || index >= Count)
+ if (index < 0 || index >= src.Count)
throw new ArgumentOutOfRangeException(nameof(index));
llama_sampler_chain_add(
@@ -193,6 +193,22 @@ public void Remove(int index)
}
+    /// <summary>
+    /// Add a custom sampler stage
+    /// </summary>
+    /// <typeparam name="TSampler">Concrete type of the custom sampler; must be a reference type implementing ICustomSampler</typeparam>
+    /// <param name="sampler">The custom sampler. It is wrapped in a CustomSamplerHandle and the handle's native sampler pointer is appended to this chain</param>
+    public void AddCustom<TSampler>(TSampler sampler)
+        where TSampler : class, ICustomSampler
+    {
+        unsafe
+        {
+            var samplerHandle = CustomSamplerHandle.Create(sampler);
+            llama_sampler_chain_add(this, (IntPtr)samplerHandle.GetLLamaSamplerPointer());
+        }
+    }
+
+
///
/// Add a sampler which picks the most likely token.
///
@@ -502,70 +518,213 @@ public record struct LLamaLogitBias
public float Bias;
}
-/* todo: Custom sampler stuff
///
///
///
/// llama_sampler_i
-public struct LLamaSamplerINative
+// Table of six callbacks (Name, Accept, Apply, Reset, Clone, Free) laid out sequentially.
+// Presumably the field order must match the native llama_sampler_i struct - TODO confirm against llama.h.
+[StructLayout(LayoutKind.Sequential)]
+internal struct LLamaSamplerINative
{
- // Delegate definitions for the function pointers
+ /// <summary>
+ /// Get the name of this sampler
+ /// </summary>
+ /// <param name="smpl">The native sampler struct this call is made on</param>
+ /// <returns>The name of the sampler</returns>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate string NameDelegate(IntPtr smpl);
+ public delegate string NameDelegate(ref LLamaSamplerNative smpl);
+ /// <summary>
+ /// Update internal sampler state after a token has been chosen
+ /// </summary>
+ /// <param name="smpl">The native sampler struct this call is made on</param>
+ /// <param name="token">The token which was chosen</param>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate void AcceptDelegate(IntPtr smpl, LLamaToken token);
+ public delegate void AcceptDelegate(ref LLamaSamplerNative smpl, LLamaToken token);
+ /// <summary>
+ /// Apply this sampler to a set of logits
+ /// </summary>
+ /// <param name="smpl">The native sampler struct this call is made on</param>
+ /// <param name="cur_p">The candidate token data, passed by reference so the sampler can modify it</param>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate void ApplyDelegate(IntPtr smpl, ref LLamaTokenDataArrayNative cur_p);
+ public delegate void ApplyDelegate(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative cur_p);
+ /// <summary>
+ /// Reset the internal state of this sampler
+ /// </summary>
+ /// <param name="smpl">The native sampler struct this call is made on</param>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate void ResetDelegate(IntPtr smpl);
+ public delegate void ResetDelegate(ref LLamaSamplerNative smpl);
+ /// <summary>
+ /// Create a clone of this sampler
+ /// </summary>
+ /// <param name="smpl">The native sampler struct to clone</param>
+ /// <returns>Pointer to the native sampler struct of the clone</returns>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate IntPtr CloneDelegate(IntPtr smpl);
+ public unsafe delegate LLamaSamplerNative* CloneDelegate(ref LLamaSamplerNative smpl);
+ /// <summary>
+ /// Free all resources held by this sampler
+ /// </summary>
+ /// <param name="smpl">The native sampler struct whose resources are released</param>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- public delegate void FreeDelegate(IntPtr smpl);
-
- // Struct fields corresponding to function pointers
- public NameDelegate name;
- public AcceptDelegate accept;
- public ApplyDelegate apply;
- public ResetDelegate reset;
- public CloneDelegate clone;
- public FreeDelegate free;
+ public delegate void FreeDelegate(ref LLamaSamplerNative smpl);
+
+ // Function pointer slots, one per delegate type above. CustomSamplerHandle.Create fills each slot
+ // with the pointer obtained for its static callback of the same name.
+ // NOTE(review): the function-pointer type arguments are not visible in this listing - confirm
+ // the declared signatures against the real file.
+ public unsafe delegate* Name;
+ public unsafe delegate* Accept;
+ public unsafe delegate* Apply;
+ public unsafe delegate* Reset;
+ public unsafe delegate* Clone;
+ public unsafe delegate* Free;
}
///
///
///
/// llama_sampler
+// Sequential layout keeps Interface ahead of Context. Presumably this has to match the native
+// llama_sampler struct - TODO confirm against llama.h.
+[StructLayout(LayoutKind.Sequential)]
internal unsafe struct LLamaSamplerNative
{
- public LLamaSamplerINative* iface;
- public IntPtr ctx;
+ /// <summary>
+ /// Holds the function pointers which make up the actual sampler
+ /// </summary>
+ public LLamaSamplerINative* Interface;
+
+ /// <summary>
+ /// Any additional context this sampler needs, may be anything. We will use it
+ /// to hold a GCHandle.
+ /// </summary>
+ public IntPtr Context;
}
-internal class CustomSamplerWrapper
+/// <summary>
+/// Bridges a managed ICustomSampler to a native sampler struct. Create allocates a LLamaSamplerNative
+/// and a LLamaSamplerINative in unmanaged memory, points the interface slots at the static callbacks
+/// of this class and stores a GCHandle to this object in the native Context field. Each callback
+/// recovers the handle from that field and forwards to the wrapped sampler.
+/// </summary>
+internal class CustomSamplerHandle
{
- public GCHandle Handle;
- public ICustomSampler Sampler;
+ /// <summary>
+ /// This GCHandle roots this object, preventing it from being freed.
+ /// </summary>
+ private GCHandle _gcHandle;
+
+ /// <summary>
+ /// A reference to the user code which implements the custom sampler
+ /// </summary>
+ private readonly ICustomSampler _sampler;
+
+ // Unmanaged allocations made in Create and released (then nulled) in the Free callback
+ private unsafe LLamaSamplerNative* _samplerNativePtr;
+ private unsafe LLamaSamplerINative* _samplerNativeInterfacePtr;
+
+ /// <summary>
+ /// Private so that instances are only made through Create, which also sets up the native state
+ /// </summary>
+ /// <param name="sampler">The user sampler to wrap</param>
+ private CustomSamplerHandle(ICustomSampler sampler)
+ {
+ _sampler = sampler;
+ }
+
+ /// <summary>
+ /// Wrap a custom sampler and build the native structs which describe it
+ /// </summary>
+ /// <param name="sampler">The user sampler to wrap</param>
+ /// <returns>A handle whose native pointer is available from GetLLamaSamplerPointer</returns>
+ public static CustomSamplerHandle Create(ICustomSampler sampler)
+ {
+ var handle = new CustomSamplerHandle(sampler);
+ handle._gcHandle = GCHandle.Alloc(handle);
+
+ unsafe
+ {
+ // Fill the callback table in unmanaged memory.
+ // NOTE(review): the delegates made from these method groups are not stored in any field of this
+ // class. The Marshal.GetFunctionPointerForDelegate documentation requires the caller to keep the
+ // delegate alive for as long as native code may call the pointer - confirm these cannot be collected.
+ handle._samplerNativeInterfacePtr = (LLamaSamplerINative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerINative));
+ handle._samplerNativeInterfacePtr->Name = (delegate*)Marshal.GetFunctionPointerForDelegate(Name);
+ handle._samplerNativeInterfacePtr->Accept = (delegate*)Marshal.GetFunctionPointerForDelegate(Accept);
+ handle._samplerNativeInterfacePtr->Apply = (delegate*)Marshal.GetFunctionPointerForDelegate(Apply);
+ handle._samplerNativeInterfacePtr->Reset = (delegate*)Marshal.GetFunctionPointerForDelegate(Reset);
+ handle._samplerNativeInterfacePtr->Clone = (delegate*)Marshal.GetFunctionPointerForDelegate(Clone);
+ handle._samplerNativeInterfacePtr->Free = (delegate*)Marshal.GetFunctionPointerForDelegate(Free);
+
+ // The sampler struct itself: Context carries the GCHandle so the callbacks can find this object
+ handle._samplerNativePtr = (LLamaSamplerNative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerNative));
+ handle._samplerNativePtr->Context = (IntPtr)handle._gcHandle;
+ handle._samplerNativePtr->Interface = handle._samplerNativeInterfacePtr;
+ }
+
+ return handle;
+ }
+
+ /// <summary>
+ /// Get a pointer to a `llama_sampler` (LLamaSamplerNative) struct, suitable for passing to `llama_sampler_chain_add`
+ /// </summary>
+ /// <returns>The pointer allocated in Create; null once the Free callback has run</returns>
+ public unsafe LLamaSamplerNative* GetLLamaSamplerPointer()
+ {
+ return _samplerNativePtr;
+ }
+
+ /// <summary>
+ /// Recover the managed handle from the GCHandle stored in the Context field of a native sampler
+ /// </summary>
+ /// <param name="smpl">A native sampler struct which was set up by Create</param>
+ /// <returns>The handle object which owns the native struct</returns>
+ private static CustomSamplerHandle GetSampler(ref LLamaSamplerNative smpl)
+ {
+ return (CustomSamplerHandle)GCHandle.FromIntPtr(smpl.Context).Target!;
+ }
+
+ /// <summary>
+ /// Name callback: returns the Name of the wrapped sampler
+ /// </summary>
+ /// <remarks>
+ /// NOTE(review): the result is a managed string which the runtime marshals for the native caller -
+ /// confirm who is expected to own and release the marshalled buffer.
+ /// </remarks>
+ private static string Name(ref LLamaSamplerNative smpl)
+ {
+ return GetSampler(ref smpl)._sampler.Name;
+ }
+
+ /// <summary>
+ /// Accept callback: forwards the chosen token to the wrapped sampler
+ /// </summary>
+ private static void Accept(ref LLamaSamplerNative smpl, LLamaToken token)
+ {
+ GetSampler(ref smpl)._sampler.Accept(token);
+ }
+
+ /// <summary>
+ /// Apply callback: hands the candidate data to the wrapped sampler by reference
+ /// </summary>
+ private static void Apply(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative candidates)
+ {
+ GetSampler(ref smpl)._sampler.Apply(ref candidates);
+ }
+
+ /// <summary>
+ /// Reset callback: forwards to the wrapped sampler
+ /// </summary>
+ private static void Reset(ref LLamaSamplerNative smpl)
+ {
+ GetSampler(ref smpl)._sampler.Reset();
+ }
+
+ /// <summary>
+ /// Clone callback: clones the wrapped sampler and wraps the clone in a new handle made by Create
+ /// </summary>
+ /// <returns>The native sampler pointer of the new handle</returns>
+ private static unsafe LLamaSamplerNative* Clone(ref LLamaSamplerNative smpl)
+ {
+ var sampler = GetSampler(ref smpl);
+
+ return Create(sampler._sampler.Clone()).GetLLamaSamplerPointer();
+ }
+
+ /// <summary>
+ /// Free callback: releases both unmanaged allocations (each only if still set), frees the GCHandle
+ /// and finally disposes the wrapped sampler
+ /// </summary>
+ private static unsafe void Free(ref LLamaSamplerNative smpl)
+ {
+ var sampler = GetSampler(ref smpl);
+
+ if (sampler._samplerNativePtr != null)
+ {
+ Marshal.FreeHGlobal((IntPtr)sampler._samplerNativePtr);
+ sampler._samplerNativePtr = null;
+ }
+
+ if (sampler._samplerNativeInterfacePtr != null)
+ {
+ Marshal.FreeHGlobal((IntPtr)sampler._samplerNativeInterfacePtr);
+ sampler._samplerNativeInterfacePtr = null;
+ }
+
+ // From here on nothing roots this handle object any more
+ sampler._gcHandle.Free();
+
+ sampler._sampler.Dispose();
+ }
}
///
/// A custom sampler stage for modifying logits or selecting a token
///
public interface ICustomSampler
+ : IDisposable
{
///
- /// The name of this stage
+ /// The human readable name of this stage
///
string Name { get; }
///
- /// Apply this stage to a set of logits
+ /// Apply this stage to a set of logits.
+ /// This can modify logits or select a token (or both).
+ /// If logits are modified the Sorted flag must be set to false.
///
+ ///
+ /// If the logits are no longer sorted after the custom sampler has run it is critically important to
+ /// set Sorted=false. If unsure, always set it to false, this is a safe default.
+ ///
///
void Apply(ref LLamaTokenDataArrayNative tokenData);
@@ -584,10 +743,4 @@ public interface ICustomSampler
/// Create a clone of this sampler
///
ICustomSampler Clone();
-
- ///
- /// Free all unmanaged resources held by this sampler
- ///
- void Free();
-}
-*/
\ No newline at end of file
+}
\ No newline at end of file