Skip to content

Commit

Permalink
Merge pull request #24 from JacobLinCool/gc
Browse files Browse the repository at this point in the history
WhisperModel can be automatically freed by GC
  • Loading branch information
JacobLinCool authored Jan 29, 2024
2 parents fb27cc7 + 0f741db commit 7d27f7a
Show file tree
Hide file tree
Showing 17 changed files with 271 additions and 138 deletions.
5 changes: 5 additions & 0 deletions .changeset/clever-cherries-judge.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"smart-whisper": minor
---

WhisperModel will now be automatically freed by the Node.js garbage collector if `.free()` has not been called previously.
3 changes: 2 additions & 1 deletion binding.gyp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
{
"target_name": "smart-whisper",
"sources": [
"src/binding.cc",
"src/binding/binding.cc",
"src/binding/common.cc",
"src/binding/model.cc",
"src/binding/transcribe.cc",
"<!@(node -p \"require('./dist/build.js').sources\")"
Expand Down
15 changes: 15 additions & 0 deletions examples/gc.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { WhisperModel, manager } from "../dist";

// Path handed to `WhisperModel.load` on every iteration (the "tiny" model).
const fp = manager.resolve("tiny");

/**
 * Load a model inside a short-lived scope five times, requesting a GC pass
 * after each one so the model that just went out of scope can be collected.
 *
 * `global.gc` is only defined when Node.js is started with `--expose-gc`;
 * without that flag the optional call does nothing.
 *
 * The promise is explicitly marked fire-and-forget and given a rejection
 * handler so a failed load is reported and reflected in the exit code instead
 * of surfacing as an unhandled rejection.
 */
void (async () => {
    for (let i = 0; i < 5; i++) {
        await scope();
        global.gc?.();
    }
})().catch((error: unknown) => {
    console.error(error);
    process.exitCode = 1;
});

/**
 * Load the model into a local binding that dies when this function returns,
 * so nothing references the instance afterwards. Logs the model's handle.
 */
async function scope() {
    const loaded = await WhisperModel.load(fp);
    console.log(loaded.handle);
}
14 changes: 0 additions & 14 deletions src/binding.cc

This file was deleted.

37 changes: 31 additions & 6 deletions src/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,29 @@ const module = require(path.join(__dirname, "../build/Release/smart-whisper"));
/**
* An external handle to a model.
*/
export type Handle = unknown;
export type Handle = {
readonly "": unique symbol;
};

export interface Binding {
export namespace Binding {
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not.
* @param callback A callback that will be called with the handle to the model.
*/
load(file: string, gpu: boolean, callback: (handle: Handle) => void): void;
export declare function load(
file: string,
gpu: boolean,
callback: (handle: Handle) => void,
): void;

/**
* Release the memory of the model, it will be unusable after this.
* @param handle The handle to the model.
* @param callback A callback that will be called when the model is freed.
*/
free(handle: Handle, callback: () => void): void;
export declare function free(handle: Handle, callback: () => void): void;

/**
* Transcribe a PCM buffer.
Expand All @@ -34,16 +40,35 @@ export interface Binding {
* @param finish A callback that will be called when the transcription is finished.
* @param progress A callback that will be called when a new result is available.
*/
transcribe<Format extends TranscribeFormat>(
export declare function transcribe<Format extends TranscribeFormat>(
handle: Handle,
pcm: Float32Array,
params: Partial<TranscribeParams<Format>>,
finish: (results: TranscribeResult<Format>[]) => void,
progress: (result: TranscribeResult<Format>) => void,
): void;

/**
* Type declaration for the `WhisperModel` class exported by the native addon.
* Instances are normally obtained through {@link WhisperModel.load}.
*/
export declare class WhisperModel {
/**
* Private member of the declaration.
* NOTE(review): presumably a stand-in for the native context pointer — confirm.
*/
private _ctx;
/**
* Wrap an existing native handle.
* The native constructor throws a `TypeError` unless exactly one argument is passed.
* @param handle The external handle to wrap.
*/
constructor(handle: Handle);
/**
* The external handle of the model, or `null` once the model has been freed.
*/
get handle(): Handle | null;
/**
* Whether the model has already been freed.
*/
get freed(): boolean;
/**
* Release the memory of the model, it will be unusable after this.
* It's safe to call this multiple times, but it will only free the model once.
*/
free(): Promise<void>;
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not. The native side uses `true` when omitted.
* @returns A promise that resolves to a {@link WhisperModel}.
*/
static load(file: string, gpu?: boolean): Promise<WhisperModel>;
}
}

/**
* The native binding for the underlying C++ addon.
*/
export const binding: Binding = module;
export const binding: typeof Binding = module;
18 changes: 18 additions & 0 deletions src/binding/binding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <napi.h>

#include "common.h"
#include "model.h"
#include "transcribe.h"

Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports.Set("transcribe", Napi::Function::New(env, Transcribe));
WhisperModel::Init(env, exports);

if (IsProduction(env.Global())) {
whisper_log_set([](ggml_log_level level, const char *text, void *user_data) {}, nullptr);
}

return exports;
}

NODE_API_MODULE(whisper, Init)
16 changes: 16 additions & 0 deletions src/binding/common.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "common.h"

// Return the JavaScript promise that belongs to this worker's deferred.
Napi::Promise PromiseWorker::Promise() { return promise.Promise(); }

/**
 * Report whether `process.env.NODE_ENV` equals "production".
 *
 * @param global_env The JavaScript global object to read `process` from.
 * @return true only when NODE_ENV is a string equal to "production".
 *
 * Each hop of global -> process -> env is checked before it is treated as an
 * object, so a missing or non-object value yields false instead of a failed
 * property lookup on a non-object.
 */
bool IsProduction(const Napi::Object global_env) {
    Napi::Value process = global_env.Get("process");
    if (!process.IsObject()) {
        return false;
    }

    Napi::Value env = process.As<Napi::Object>().Get("env");
    if (!env.IsObject()) {
        return false;
    }

    Napi::Value node_env = env.As<Napi::Object>().Get("NODE_ENV");
    if (!node_env.IsString()) {
        return false;
    }

    return node_env.As<Napi::String>().Utf8Value() == "production";
}
17 changes: 17 additions & 0 deletions src/binding/common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
#ifndef _GUARD_SW_COMMON_H
#define _GUARD_SW_COMMON_H

#ifndef NAPI_VERSION
// Support Node.js 16+
#define NAPI_VERSION 8
#endif
#include <napi.h>

/**
 * Napi::AsyncWorker base that owns a Napi::Promise::Deferred, so subclasses
 * report their result by settling a JavaScript promise rather than by
 * invoking a callback.
 */
class PromiseWorker : public Napi::AsyncWorker {
   public:
    // Bind the worker to `env` and create the deferred in that environment.
    PromiseWorker(Napi::Env &env) : AsyncWorker(env), promise(Napi::Promise::Deferred::New(env)) {}

    // The promise that is settled when the worker finishes.
    Napi::Promise Promise();

   protected:
    // The worker is constructed without a callback, so the inherited OnError
    // has nothing to call and the promise would stay pending forever after a
    // failure. Reject it with the reported error instead. Subclasses may
    // still override this.
    void OnError(const Napi::Error &error) override { promise.Reject(error.Value()); }

    // Deferred that subclasses resolve from their OnOK handler.
    Napi::Promise::Deferred promise;
};

// True when `process.env.NODE_ENV`, read from `global_env`, is the string "production".
bool IsProduction(const Napi::Object global_env);

#endif
146 changes: 96 additions & 50 deletions src/binding/model.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "model.h"

class LoadModelWorker : public Napi::AsyncWorker {
class LoadModelWorker : public PromiseWorker {
public:
LoadModelWorker(Napi::Function &callback, const std::string &model_path,
LoadModelWorker(Napi::Env &env, const std::string &model_path,
struct whisper_context_params params)
: AsyncWorker(callback), model_path(model_path), params(params) {}
: PromiseWorker(env), model_path(model_path), params(params) {}

void Execute() override {
context = whisper_init_from_file_with_params_no_state(model_path.c_str(), params);
Expand All @@ -15,10 +15,12 @@ class LoadModelWorker : public Napi::AsyncWorker {
}

void OnOK() override {
Napi::HandleScope scope(Env());
Napi::External<whisper_context> contextHandle =
Napi::External<whisper_context>::New(Env(), context);
Callback().Call({contextHandle});
Napi::HandleScope scope(Env());
auto handle = Napi::External<whisper_context>::New(Env(), context);
auto constructor = Env().GetInstanceData<Napi::FunctionReference>();
auto model = constructor->New({handle});

promise.Resolve(model);
}

private:
Expand All @@ -27,73 +29,117 @@ class LoadModelWorker : public Napi::AsyncWorker {
whisper_context *context;
};

bool IsProduction(const Napi::Object global_env) {
Napi::Object process = global_env.Get("process").As<Napi::Object>();
Napi::Object env = process.Get("env").As<Napi::Object>();
Napi::Value nodeEnv = env.Get("NODE_ENV");
if (nodeEnv.IsString()) {
Napi::String nodeEnvStr = nodeEnv.As<Napi::String>();
std::string envStr = nodeEnvStr.Utf8Value();
return envStr == "production";
class FreeModelWorker : public PromiseWorker {
public:
FreeModelWorker(Napi::Env &env, whisper_context *context)
: PromiseWorker(env), context(context) {}

void Execute() override { whisper_free(context); }

void OnOK() override {
Napi::HandleScope scope(Env());
promise.Resolve(Env().Undefined());
}
return false;

private:
whisper_context *context;
};

/**
 * Define the JavaScript `WhisperModel` class and attach it to `exports`.
 *
 * The class exposes a static `load`, an instance `free`, and the read-only
 * accessors `freed` and `handle`. A persistent reference to the constructor
 * is stored as the environment's instance data so native code can create
 * instances later.
 */
Napi::Object WhisperModel::Init(Napi::Env env, Napi::Object exports) {
    // Methods: writable + configurable. Accessors: enumerable + configurable.
    const auto method_attrs =
        static_cast<napi_property_attributes>(napi_writable | napi_configurable);
    const auto accessor_attrs =
        static_cast<napi_property_attributes>(napi_enumerable | napi_configurable);

    Napi::Function func = DefineClass(
        env, "WhisperModel",
        {
            StaticMethod<&WhisperModel::Load>("load", method_attrs),
            InstanceMethod<&WhisperModel::Free>("free", method_attrs),
            InstanceAccessor("freed", &WhisperModel::GetFreed, nullptr, accessor_attrs),
            InstanceAccessor("handle", &WhisperModel::GetHandle, nullptr, accessor_attrs),
        });

    // Keep the constructor alive for the lifetime of the environment.
    auto ctor_ref = new Napi::FunctionReference();
    *ctor_ref = Napi::Persistent(func);
    env.SetInstanceData<Napi::FunctionReference>(ctor_ref);

    exports.Set("WhisperModel", func);
    return exports;
}

Napi::Value LoadModel(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
WhisperModel::WhisperModel(const Napi::CallbackInfo &info) : Napi::ObjectWrap<WhisperModel>(info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (info.Length() != 3) {
if (info.Length() != 1) {
Napi::TypeError::New(env, "Wrong number of arguments").ThrowAsJavaScriptException();
return env.Null();
return;
}

std::string model_path = info[0].As<Napi::String>();
whisper_context *context = info[0].As<Napi::External<whisper_context>>().Data();
this->context = context;
}

struct whisper_context_params params;
params.use_gpu = info[1].As<Napi::Boolean>();
/**
 * Finalizer hook: release the whisper context if this object still owns one.
 *
 * `context` is already null when the model was freed explicitly, in which
 * case there is nothing left to do.
 */
void WhisperModel::Finalize(Napi::Env env) {
    if (context == nullptr) {
        return;
    }

    whisper_free(context);
    context = nullptr;
}

Napi::Function callback = info[2].As<Napi::Function>();
Napi::Value WhisperModel::Load(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

if (IsProduction(env.Global())) {
whisper_log_set([](ggml_log_level level, const char *text, void *user_data) {}, nullptr);
if (info.Length() < 1 || info.Length() > 2) {
Napi::TypeError::New(env, "Wrong number of arguments").ThrowAsJavaScriptException();
return env.Null();
}

LoadModelWorker *worker = new LoadModelWorker(callback, model_path, params);
std::string model_path = info[0].As<Napi::String>();

whisper_context_params params;
params.use_gpu = info.Length() == 2 ? info[1].As<Napi::Boolean>() : true;

auto worker = new LoadModelWorker(env, model_path, params);
worker->Queue();

return env.Undefined();
return worker->Promise();
}

class FreeModelWorker : public Napi::AsyncWorker {
public:
FreeModelWorker(Napi::Function &callback, whisper_context *context)
: AsyncWorker(callback), context(context) {}
Napi::Value WhisperModel::Free(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

void Execute() override { whisper_free(context); }
if (info.Length() != 0) {
Napi::TypeError::New(env, "Wrong number of arguments").ThrowAsJavaScriptException();
return env.Null();
}

void OnOK() override {
Napi::HandleScope scope(Env());
Callback().Call({});
if (context == nullptr) {
auto deferred = Napi::Promise::Deferred::New(env);
deferred.Resolve(env.Undefined());
return deferred.Promise();
} else {
auto worker = new FreeModelWorker(env, context);
context = nullptr;
worker->Queue();
return worker->Promise();
}
}

private:
whisper_context *context;
};
/**
 * Getter for the `freed` accessor: a JavaScript boolean that is true when
 * this object no longer holds a whisper context.
 */
Napi::Value WhisperModel::GetFreed(const Napi::CallbackInfo &info) {
    const bool is_freed = (context == nullptr);
    return Napi::Boolean::New(info.Env(), is_freed);
}

Napi::Value FreeModel(const Napi::CallbackInfo &info) {
Napi::Value WhisperModel::GetHandle(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

if (info.Length() != 2) {
Napi::TypeError::New(env, "Wrong number of arguments").ThrowAsJavaScriptException();
if (context == nullptr) {
return env.Null();
}

whisper_context *context = info[0].As<Napi::External<whisper_context>>().Data();

Napi::Function callback = info[1].As<Napi::Function>();

FreeModelWorker *worker = new FreeModelWorker(callback, context);
worker->Queue();

return env.Undefined();
return Napi::External<whisper_context>::New(env, context);
}
21 changes: 19 additions & 2 deletions src/binding/model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
#ifndef _GUARD_SW_MODEL_H
#define _GUARD_SW_MODEL_H

#include "common.h"
#include "whisper.h"

Napi::Value LoadModel(const Napi::CallbackInfo& info);
Napi::Value FreeModel(const Napi::CallbackInfo& info);
/**
 * Native object behind the JavaScript `WhisperModel` class. It owns a
 * whisper_context and releases it either on an explicit `free()` or when the
 * wrapper is finalized.
 */
class WhisperModel : public Napi::ObjectWrap<WhisperModel> {
   public:
    // Register the class on `exports` and remember its constructor.
    static Napi::Object Init(Napi::Env env, Napi::Object exports);

    // Expects a single External<whisper_context> argument.
    WhisperModel(const Napi::CallbackInfo &info);
    // Frees the context if it has not been freed already.
    void Finalize(Napi::Env env);

   private:
    // Owned context; nullptr means "already freed". Initialised here so the
    // pointer is never indeterminate, even when the constructor returns early
    // after rejecting its arguments.
    whisper_context *context = nullptr;
    // JS: static load(file, gpu?) -> Promise<WhisperModel>
    static Napi::Value Load(const Napi::CallbackInfo &info);
    // JS: free() -> Promise<void>
    Napi::Value Free(const Napi::CallbackInfo &info);
    // JS: getter for `freed`
    Napi::Value GetFreed(const Napi::CallbackInfo &info);
    // JS: getter for `handle`
    Napi::Value GetHandle(const Napi::CallbackInfo &info);
};

#endif
Loading

0 comments on commit 7d27f7a

Please sign in to comment.