Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved FromBuffer() to accept Memory<byte> for better memory management #316

Merged
merged 2 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Whisper.net.Demo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ await Parser.Default.ParseArguments<Options>(args)
.WithParsedAsync(Demo);

async Task Demo(Options opt)

{
if (!File.Exists(opt.ModelName))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Buffers;
using System.Runtime.InteropServices;
using Whisper.net.Internals.Native;
using Whisper.net.Native;

namespace Whisper.net.Internals.ModelLoader;

internal class WhisperProcessorModelBufferLoader : IWhisperProcessorModelLoader
internal class WhisperProcessorModelMemoryLoader : IWhisperProcessorModelLoader
{
private readonly GCHandle pinnedBuffer;
private readonly MemoryHandle pinnedMemory;
private readonly WhisperAheads aHeads;
private readonly GCHandle? aheadsHandle;
private readonly UIntPtr bufferLength;

private readonly WhisperFactoryOptions options;

public WhisperProcessorModelBufferLoader(byte[] buffer, WhisperFactoryOptions options)
public WhisperProcessorModelMemoryLoader(Memory<byte> buffer, WhisperFactoryOptions options)
{
this.options = options;

pinnedBuffer = GCHandle.Alloc(buffer, GCHandleType.Pinned);
pinnedMemory = buffer.Pin();
aHeads = ModelLoaderUtils.GetWhisperAlignmentHeads(options.CustomAlignmentHeads, out aheadsHandle);
bufferLength = new UIntPtr((uint)buffer.Length);
}

public void Dispose()
{
pinnedBuffer.Free();
pinnedMemory.Dispose();
if (aheadsHandle.HasValue)
{
aheadsHandle.Value.Free();
}
}

public IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
public unsafe IntPtr LoadNativeContext(INativeWhisper nativeWhisper)
{
return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State(pinnedBuffer.AddrOfPinnedObject(), bufferLength,
return nativeWhisper.Whisper_Init_From_Buffer_With_Params_No_State((IntPtr)pinnedMemory.Pointer, bufferLength,
new WhisperContextParams()
{
UseGpu = options.UseGpu.AsByte(),
Expand Down
16 changes: 8 additions & 8 deletions Whisper.net/WhisperFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,30 +116,30 @@ public static WhisperFactory FromPath(string path, WhisperFactoryOptions options
}

/// <summary>
/// Creates a factory that uses the ggml model from a buffer in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// Creates a factory that uses the ggml model from a buffer in memory in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// </summary>
/// <param name="buffer">The buffer with the model.</param>
/// <param name="memory">The memory buffer with the model.</param>
/// <returns>An instance to the same builder.</returns>
/// <remarks>
/// If you don't know where to find a ggml model, you can use <seealso cref="Ggml.WhisperGgmlDownloader"/> which is downloading a model from huggingface.co.
/// </remarks>
public static WhisperFactory FromBuffer(byte[] buffer)
public static WhisperFactory FromBuffer(Memory<byte> memory)
{
return FromBuffer(buffer, WhisperFactoryOptions.Default);
return FromBuffer(memory, WhisperFactoryOptions.Default);
}

/// <summary>
/// Creates a factory that uses the ggml model from a buffer in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// Creates a factory that uses the ggml model from a buffer in memory in order to create <seealso cref="WhisperProcessorBuilder"/>.
/// </summary>
/// <param name="buffer">The buffer with the model.</param>
/// <param name="memory">The memory buffer with the model.</param>
/// <param name="options">Thhe options for the factory and the loading of the model.</param>
/// <returns>An instance to the same builder.</returns>
/// <remarks>
/// If you don't know where to find a ggml model, you can use <seealso cref="Ggml.WhisperGgmlDownloader"/> which is downloading a model from huggingface.co.
/// </remarks>
public static WhisperFactory FromBuffer(byte[] buffer, WhisperFactoryOptions options)
public static WhisperFactory FromBuffer(Memory<byte> memory, WhisperFactoryOptions options)
{
return new WhisperFactory(new WhisperProcessorModelBufferLoader(buffer, options), options.DelayInitialization);
return new WhisperFactory(new WhisperProcessorModelMemoryLoader(memory, options), options.DelayInitialization);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion tests/Whisper.net.Maui.Tests/MainPage.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ private async void ContentPage_Loaded(object sender, EventArgs e)
await mauiStream.CopyToAsync(audioFileStream);
audioFileStream.Seek(0, SeekOrigin.Begin);

using var whisperFactory = WhisperFactory.FromBuffer(memoryStream.ToArray());
using var whisperFactory = WhisperFactory.FromBuffer(memoryStream.GetBuffer().AsMemory(0, (int)memoryStream.Length));
var whisperBuilder = whisperFactory.CreateBuilder();
using var whisperProcessor = whisperBuilder.Build();
LblResult.Text = string.Empty;
Expand Down
6 changes: 3 additions & 3 deletions tests/Whisper.net.Tests/FactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ public void CreateBuilder_WithFileModel_ShouldReturnBuilder()
}

[Fact]
public void CreateBuilder_WithBufferedModel_ShouldReturnBuilder()
public void CreateBuilder_WithMemoryModel_ShouldReturnBuilder()
{
var buffer = File.ReadAllBytes(model.ModelFile);
using var factory = WhisperFactory.FromBuffer(buffer);
var memoryBuffer = File.ReadAllBytes(model.ModelFile);
using var factory = WhisperFactory.FromBuffer(memoryBuffer);
var builder = factory.CreateBuilder();
builder.Should().NotBeNull();
}
Expand Down
Loading