Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Sampler Stages #961

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class ExampleRunner
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Custom Sampling Pipeline", CustomSampler.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
Expand Down
114 changes: 114 additions & 0 deletions LLama.Examples/Examples/CustomSampler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using LLama.Common;
using LLama.Examples.Extensions;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples
{
/// <summary>
/// Example showing how to plug a hand-built sampler chain — including a fully
/// custom <see cref="ICustomSampler"/> stage — into inference via
/// <see cref="InferenceParams.SamplingPipeline"/>.
/// </summary>
public class CustomSampler
{
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var weights = await LLamaWeights.LoadFromFileAsync(parameters);

        var executor = new StatelessExecutor(weights, parameters);

        // Explain what this demo does. Note: the "\n" separates the two
        // paragraphs — previously an empty string was concatenated here,
        // which ran the sentences together in the console output.
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine(
            "In this example a custom sampling pipeline with a custom sampler stage is being used. This demonstrates how to customise the samplers used, and " +
            "how to create a completely custom sampler stage which modifies the logits or selects a token." +
            "\n" +
            "In this case the custom sampler stage removes the most likely token. This will probably produce bad results, it's just a demo!"
        );
        Console.ForegroundColor = ConsoleColor.White;

        var inferenceParams = new InferenceParams
        {
            SamplingPipeline = new CustomSamplingPipeline(),
            MaxTokens = 50
        };

        // Interactive loop: read a question, stream the sampled answer back.
        while (true)
        {
            Console.Write("\nQuestion: ");
            Console.ForegroundColor = ConsoleColor.Green;
            var question = Console.ReadLine();
            Console.ForegroundColor = ConsoleColor.White;
            Console.Write("Answer: ");

            var prompt = $"Question: {question?.Trim()} Answer: ";
            await foreach (var text in executor.InferAsync(prompt, inferenceParams).Spinner())
                Console.Write(text);
        }
    }
}

/// <summary>
/// A sampling pipeline assembled by hand from individual native sampler stages:
/// top-k filtering, then a custom stage that bans the single most likely token,
/// then softmax and random selection from the remaining distribution.
/// </summary>
public class CustomSamplingPipeline
    : BaseSamplingPipeline
{
    /// <summary>
    /// Build the native sampler chain. Stage order matters: each stage sees
    /// the candidate set as left by the previous stage.
    /// </summary>
    /// <param name="context">Context the chain will sample for.</param>
    /// <returns>Handle to the assembled sampler chain.</returns>
    protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
    {
        var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

        // Take only the 10 most likely tokens
        chain.AddTopK(10);

        // Remove the most likely token
        chain.AddCustom(new RemoveMostLikelyToken());

        // Select from the distribution
        chain.AddSoftmax();
        chain.AddDistributionSampler(42); // fixed seed => reproducible sampling

        return chain;
    }
}

/// <summary>
/// A custom sampler stage which makes the single most likely candidate token
/// impossible to select. This will probably produce bad output — it exists only
/// to demonstrate the <see cref="ICustomSampler"/> API.
/// (Stray GitHub review-thread text that had been pasted into this class body
/// has been removed; it was not valid C#.)
/// </summary>
public class RemoveMostLikelyToken
    : ICustomSampler
{
    /// <inheritdoc />
    public string Name => "Remove Most Likely Token";

    /// <summary>
    /// Modify the candidate array in place, setting the highest logit to
    /// negative infinity so that token can never be picked.
    /// </summary>
    /// <param name="tokenData">Candidate tokens and their logits.</param>
    public void Apply(ref LLamaTokenDataArrayNative tokenData)
    {
        // Doesn't make sense to run this stage if there is only one candidate left
        if (tokenData.Size <= 1)
            return;

        // Ensure token data is sorted, so most likely token is first.
        // Note that this is a descending sort, the **largest** value is first.
        if (!tokenData.Sorted)
            tokenData.Data.Sort((a, b) => b.Logit.CompareTo(a.Logit));

        // Make the most likely token impossible to pick
        tokenData.Data[0].Logit = float.NegativeInfinity;

        // It's **critically** important to set this if the logits are no longer sorted after the custom
        // sampler has run. If you're not sure, it's always safer to set it to false.
        //
        // In this case, because the first logit has just been set to negative infinity
        // the token data is definitely not sorted!
        tokenData.Sorted = false;
    }

    /// <inheritdoc />
    public void Accept(LLamaToken token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public ICustomSampler Clone()
    {
        return new RemoveMostLikelyToken();
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
}
22 changes: 18 additions & 4 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public struct LLamaTokenDataArrayNative
/// <summary>
/// Number of LLamaTokenData in the array
/// </summary>
public ulong size;
private ulong _size;

/// <summary>
/// The index in the array (i.e. not the token id)
Expand All @@ -167,13 +167,13 @@ public Span<LLamaTokenData> Data
{
unsafe
{
return new Span<LLamaTokenData>(_data, checked((int)size));
return new Span<LLamaTokenData>(_data, checked((int)Size));
}
}
}

/// <summary>
/// Indicates if the items in the array are sorted
/// Indicates if the items in the array are sorted, so the most likely token is first
/// </summary>
public bool Sorted
{
Expand All @@ -190,6 +190,20 @@ public long Selected
set => _selected = value;
}

/// <summary>
/// Number of LLamaTokenData in the array. Set this to shrink the array
/// </summary>
public ulong Size
{
    get => _size;
    set
    {
        // Shrinking is fine; growing would expose memory beyond the buffer
        // that was originally allocated, so it is rejected.
        if (value <= _size)
        {
            _size = value;
            return;
        }

        throw new ArgumentOutOfRangeException(nameof(value), "Cannot set Size property to a larger value");
    }
}

/// <summary>
/// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
/// </summary>
Expand All @@ -205,7 +219,7 @@ public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataA
native = new LLamaTokenDataArrayNative
{
_data = (LLamaTokenData*)handle.Pointer,
size = (ulong)array.Data.Length,
Size = (ulong)array.Data.Length,
Sorted = array.Sorted
};
}
Expand Down
Loading
Loading