Skip to content

Reuse session for 60% speedup #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions example/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import * as tts from '../src';
import Worker from './worker.ts?worker';
import Worker2 from './worker2.ts?worker';

// required for e2e
// Expose the tts module on window so end-to-end tests can drive it directly.
Object.assign(window, { tts });

// Render the two demo buttons: one-shot predict, and predict via a reused session.
document.querySelector('#app')!.innerHTML = `
<button id="btn" type="button">Predict</button>
<button id="btn2" type="button">Predict with Session Reuse</button>
`

document.getElementById('btn')?.addEventListener('click', async () => {
Expand All @@ -26,3 +28,21 @@ document.getElementById('btn')?.addEventListener('click', async () => {
worker.terminate();
});
});

// One long-lived worker so the TTS session (phonemizer, ONNX runtime, model)
// is initialized once and reused across clicks — this is the "60% speedup" path.
const mainWorker = new Worker2();

document.getElementById('btn2')?.addEventListener('click', async () => {
  // Register the result listener BEFORE posting the request so a fast reply
  // cannot arrive with no listener attached.
  mainWorker.addEventListener(
    'message',
    (event: MessageEvent<{ type: 'result', audio: Blob }>) => {
      if (event.data.type !== 'result') return;

      const audio = new Audio();
      audio.src = URL.createObjectURL(event.data.audio);
      // Release the blob URL once playback finishes to avoid leaking memory
      // on repeated clicks.
      audio.addEventListener('ended', () => URL.revokeObjectURL(audio.src));
      audio.play();
    },
    { once: true }
  );

  mainWorker.postMessage({
    type: 'init',
    text: "As the waves crashed against the shore, they carried tales of distant lands and adventures untold.",
    voiceId: 'en_US-hfc_female-medium',
  });
});
2 changes: 2 additions & 0 deletions example/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ import * as tts from '../src/index';
/**
 * Worker message handler: on an 'init' request, run a one-shot TTS
 * prediction for the given text/voice and post the audio blob back.
 */
async function main(event: MessageEvent<tts.InferenceConfg & { type: 'init' }>) {
  const data = event.data;
  if (data?.type != 'init') return;

  const t0 = performance.now();
  const { text, voiceId } = data;
  const blob = await tts.predict({ text, voiceId });
  console.log('Time taken:', performance.now() - t0);

  self.postMessage({ type: 'result', audio: blob })
}
Expand Down
19 changes: 19 additions & 0 deletions example/worker2.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import * as tts from '../src/index';

// Create the TTS session once at worker startup. Top-level await keeps
// incoming messages queued until the module has finished evaluating, so no
// request can observe a half-initialized session.
// NOTE: renamed from `TtsSession` — a PascalCase instance name shadowed the
// tts.TtsSession class and read as a type, not a value.
const start = performance.now();
const session = await tts.TtsSession.create({
  voiceId: 'en_US-hfc_female-medium',
});
console.log('Time taken to init session:', performance.now() - start);

/**
 * Handle an 'init' message: synthesize the requested text with the shared
 * session and post the resulting WAV blob back to the page.
 */
async function main(event: MessageEvent<tts.InferenceConfg & { type: 'init' }>) {
  if (event.data?.type !== 'init') return;

  // Local timer renamed so it does not shadow the module-level `start`.
  const t0 = performance.now();
  const blob = await session.predict(event.data.text);
  console.log('Time taken:', performance.now() - t0);

  self.postMessage({ type: 'result', audio: blob })
}

self.addEventListener('message', main);
170 changes: 119 additions & 51 deletions src/inference.ts
Original file line number Diff line number Diff line change
@@ -1,65 +1,133 @@
import { InferenceConfg, ProgressCallback } from "./types";
import { InferenceConfg, ProgressCallback, VoiceId } from "./types";
import { HF_BASE, ONNX_BASE, PATH_MAP, WASM_BASE } from './fixtures';
import { readBlob, writeBlob } from './opfs';
import { fetchBlob } from './http.js';
import { pcm2wav } from './audio';

/**
* Run text to speech inference in new worker thread. Fetches the model
* first, if it has not yet been saved to opfs yet.
*/
export async function predict(config: InferenceConfg, callback?: ProgressCallback): Promise<Blob> {
const { createPiperPhonemize } = await import('./piper.js');
const ort = await import('onnxruntime-web');

const path = PATH_MAP[config.voiceId];
const input = JSON.stringify([{ text: config.text.trim() }]);

ort.env.allowLocalModels = false;
ort.env.wasm.numThreads = navigator.hardwareConcurrency;
ort.env.wasm.wasmPaths = ONNX_BASE;

const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`);
const modelConfig = JSON.parse(await modelConfigBlob.text());

const phonemeIds: string[] = await new Promise(async resolve => {
const module = await createPiperPhonemize({
print: (data: any) => {
resolve(JSON.parse(data).phoneme_ids);
},
printErr: (message: any) => {
throw new Error(message);
},
locateFile: (url: string) => {
if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`;
if (url.endsWith(".data")) return `${WASM_BASE}.data`;
return url;
}
});
/** Options accepted by TtsSession / TtsSession.create. */
interface TtsSessionOptions {
  /** Voice model to load — key into PATH_MAP. */
  voiceId: VoiceId;
  /** Optional callback reporting model download progress. */
  progress?: ProgressCallback;
}

module.callMain(["-l", modelConfig.espeak.voice, "--input", input, "--espeak_data", "/espeak-ng-data"]);
});
export class TtsSession {
ready = false;
voiceId: VoiceId;
waitReady: Promise<void>;
#createPiperPhonemize?: (moduleArg?: {}) => any;
#modelConfig?: any;
#ort?: typeof import("onnxruntime-web");
#ortSession?: import("onnxruntime-web").InferenceSession
#progressCallback?: ProgressCallback;

const speakerId = 0;
const sampleRate = modelConfig.audio.sample_rate;
const noiseScale = modelConfig.inference.noise_scale;
const lengthScale = modelConfig.inference.length_scale;
const noiseW = modelConfig.inference.noise_w;

const modelBlob = await getBlob(`${HF_BASE}/${path}`, callback);
const session = await ort.InferenceSession.create(await modelBlob.arrayBuffer());
const feeds = {
input: new ort.Tensor("int64", phonemeIds, [1, phonemeIds.length]),
input_lengths: new ort.Tensor("int64", [phonemeIds.length]),
scales: new ort.Tensor("float32", [noiseScale, lengthScale, noiseW])
constructor({ voiceId, progress }: TtsSessionOptions) {
this.voiceId = voiceId;
this.#progressCallback = progress;
this.waitReady = this.init();
}
if (Object.keys(modelConfig.speaker_id_map).length) {
Object.assign(feeds, { sid: new ort.Tensor("int64", [speakerId]) })

static async create(options: TtsSessionOptions) {
const session = new TtsSession(options);
await session.waitReady;
return session;
}

const { output: { data: pcm } } = await session.run(feeds);
async init() {
const { createPiperPhonemize } = await import("./piper.js");
this.#createPiperPhonemize = createPiperPhonemize;
this.#ort = await import("onnxruntime-web");

this.#ort.env.allowLocalModels = false;
this.#ort.env.wasm.numThreads = navigator.hardwareConcurrency;
this.#ort.env.wasm.wasmPaths = ONNX_BASE;

const path = PATH_MAP[this.voiceId];
const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`);
this.#modelConfig = JSON.parse(await modelConfigBlob.text());

const modelBlob = await getBlob(
`${HF_BASE}/${path}`,
this.#progressCallback
);
this.#ortSession = await this.#ort.InferenceSession.create(
await modelBlob.arrayBuffer()
);
}

async predict(text: string): Promise<Blob> {
await this.waitReady; // wait for the session to be ready

const input = JSON.stringify([{ text: text.trim() }]);

const phonemeIds: string[] = await new Promise(async (resolve) => {
const module = await this.#createPiperPhonemize!({
print: (data: any) => {
resolve(JSON.parse(data).phoneme_ids);
},
printErr: (message: any) => {
throw new Error(message);
},
locateFile: (url: string) => {
if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`;
if (url.endsWith(".data")) return `${WASM_BASE}.data`;
return url;
},
});

return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], { type: "audio/x-wav" });
module.callMain([
"-l",
this.#modelConfig.espeak.voice,
"--input",
input,
"--espeak_data",
"/espeak-ng-data",
]);
});

const speakerId = 0;
const sampleRate = this.#modelConfig.audio.sample_rate;
const noiseScale = this.#modelConfig.inference.noise_scale;
const lengthScale = this.#modelConfig.inference.length_scale;
const noiseW = this.#modelConfig.inference.noise_w;

const session = this.#ortSession!;
const feeds = {
input: new this.#ort!.Tensor("int64", phonemeIds, [1, phonemeIds.length]),
input_lengths: new this.#ort!.Tensor("int64", [phonemeIds.length]),
scales: new this.#ort!.Tensor("float32", [
noiseScale,
lengthScale,
noiseW,
]),
};
if (Object.keys(this.#modelConfig.speaker_id_map).length) {
Object.assign(feeds, {
sid: new this.#ort!.Tensor("int64", [speakerId]),
});
}

const {
output: { data: pcm },
} = await session.run(feeds);

return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], {
type: "audio/x-wav",
});
}
}

/**
* Run text to speech inference in new worker thread. Fetches the model
* first, if it has not yet been saved to opfs yet.
*/
export async function predict(
config: InferenceConfg,
callback?: ProgressCallback
): Promise<Blob> {
const session = new TtsSession({
voiceId: config.voiceId,
progress: callback,
});
return session.predict(config.text);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"target": "ES2020",
"useDefineForClassFields": true,
"module": "ESNext",
"lib": ["ES2020", "DOM", "DOM.Iterable","WebWorker"],
"lib": ["ES2022", "DOM", "DOM.Iterable","WebWorker"],
"skipLibCheck": true,
"allowJs": true,

Expand Down