From f73b3eee99f77d30e25cd83f4238c17ee624b7aa Mon Sep 17 00:00:00 2001 From: Sam Willis Date: Wed, 10 Jul 2024 19:41:21 +0200 Subject: [PATCH 1/2] Reuse session for 60% speedup --- example/index.ts | 20 ++++++ example/worker.ts | 2 + example/worker2.ts | 19 +++++ src/inference.ts | 170 +++++++++++++++++++++++++++++++-------------- 4 files changed, 160 insertions(+), 51 deletions(-) create mode 100644 example/worker2.ts diff --git a/example/index.ts b/example/index.ts index 69f7928..bf8b576 100644 --- a/example/index.ts +++ b/example/index.ts @@ -1,11 +1,13 @@ import * as tts from '../src'; import Worker from './worker.ts?worker'; +import Worker2 from './worker2.ts?worker'; // required for e2e Object.assign(window, { tts }); document.querySelector('#app')!.innerHTML = ` + ` document.getElementById('btn')?.addEventListener('click', async () => { @@ -26,3 +28,21 @@ document.getElementById('btn')?.addEventListener('click', async () => { worker.terminate(); }); }); + +const mainWorker = new Worker2(); + +document.getElementById('btn2')?.addEventListener('click', async () => { + mainWorker.postMessage({ + type: 'init', + text: "As the waves crashed against the shore, they carried tales of distant lands and adventures untold.", + voiceId: 'en_US-hfc_female-medium', + }); + + mainWorker.addEventListener('message', (event: MessageEvent<{ type: 'result', audio: Blob }>) => { + if (event.data.type != 'result') return; + + const audio = new Audio(); + audio.src = URL.createObjectURL(event.data.audio); + audio.play(); + }, { once: true }); +}); diff --git a/example/worker.ts b/example/worker.ts index 5775553..b95f187 100644 --- a/example/worker.ts +++ b/example/worker.ts @@ -3,10 +3,12 @@ import * as tts from '../src/index'; async function main(event: MessageEvent) { if (event.data?.type != 'init') return; + const start = performance.now(); const blob = await tts.predict({ text: event.data.text, voiceId: event.data.voiceId, }); + console.log('Time taken:', performance.now() - 
start); self.postMessage({ type: 'result', audio: blob }) } diff --git a/example/worker2.ts b/example/worker2.ts new file mode 100644 index 0000000..407ef4c --- /dev/null +++ b/example/worker2.ts @@ -0,0 +1,19 @@ +import * as tts from '../src/index'; + +const start = performance.now(); +const TtsSession = await tts.TtsSession.create({ + voiceId: 'en_US-hfc_female-medium', +}); +console.log('Time taken to init session:', performance.now() - start); + +async function main(event: MessageEvent) { + if (event.data?.type != 'init') return; + + const start = performance.now(); + const blob = await TtsSession.predict(event.data.text); + console.log('Time taken:', performance.now() - start); + + self.postMessage({ type: 'result', audio: blob }) +} + +self.addEventListener('message', main); diff --git a/src/inference.ts b/src/inference.ts index bc39141..e8fd7fd 100644 --- a/src/inference.ts +++ b/src/inference.ts @@ -1,65 +1,133 @@ -import { InferenceConfg, ProgressCallback } from "./types"; +import { InferenceConfg, ProgressCallback, VoiceId } from "./types"; import { HF_BASE, ONNX_BASE, PATH_MAP, WASM_BASE } from './fixtures'; import { readBlob, writeBlob } from './opfs'; import { fetchBlob } from './http.js'; import { pcm2wav } from './audio'; -/** - * Run text to speech inference in new worker thread. Fetches the model - * first, if it has not yet been saved to opfs yet. 
- */ -export async function predict(config: InferenceConfg, callback?: ProgressCallback): Promise { - const { createPiperPhonemize } = await import('./piper.js'); - const ort = await import('onnxruntime-web'); - - const path = PATH_MAP[config.voiceId]; - const input = JSON.stringify([{ text: config.text.trim() }]); - - ort.env.allowLocalModels = false; - ort.env.wasm.numThreads = navigator.hardwareConcurrency; - ort.env.wasm.wasmPaths = ONNX_BASE; - - const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`); - const modelConfig = JSON.parse(await modelConfigBlob.text()); - - const phonemeIds: string[] = await new Promise(async resolve => { - const module = await createPiperPhonemize({ - print: (data: any) => { - resolve(JSON.parse(data).phoneme_ids); - }, - printErr: (message: any) => { - throw new Error(message); - }, - locateFile: (url: string) => { - if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`; - if (url.endsWith(".data")) return `${WASM_BASE}.data`; - return url; - } - }); +interface TtsSessionOptions { + voiceId: VoiceId; + progress?: ProgressCallback; +} - module.callMain(["-l", modelConfig.espeak.voice, "--input", input, "--espeak_data", "/espeak-ng-data"]); - }); +export class TtsSession { + ready = false; + voiceId: VoiceId; + waitReady: Promise; + #createPiperPhonemize?: (moduleArg?: {}) => any; + #modelConfig?: any; + #ort?: typeof import("onnxruntime-web"); + #ortSession?: import("onnxruntime-web").InferenceSession + #progressCallback?: ProgressCallback; - const speakerId = 0; - const sampleRate = modelConfig.audio.sample_rate; - const noiseScale = modelConfig.inference.noise_scale; - const lengthScale = modelConfig.inference.length_scale; - const noiseW = modelConfig.inference.noise_w; - - const modelBlob = await getBlob(`${HF_BASE}/${path}`, callback); - const session = await ort.InferenceSession.create(await modelBlob.arrayBuffer()); - const feeds = { - input: new ort.Tensor("int64", phonemeIds, [1, phonemeIds.length]), - 
input_lengths: new ort.Tensor("int64", [phonemeIds.length]), - scales: new ort.Tensor("float32", [noiseScale, lengthScale, noiseW]) + constructor({ voiceId, progress }: TtsSessionOptions) { + this.voiceId = voiceId; + this.#progressCallback = progress; + this.waitReady = this.init(); } - if (Object.keys(modelConfig.speaker_id_map).length) { - Object.assign(feeds, { sid: new ort.Tensor("int64", [speakerId]) }) + + static async create(options: TtsSessionOptions) { + const session = new TtsSession(options); + await session.waitReady; + return session; } - const { output: { data: pcm } } = await session.run(feeds); + async init() { + const { createPiperPhonemize } = await import("./piper.js"); + this.#createPiperPhonemize = createPiperPhonemize; + this.#ort = await import("onnxruntime-web"); + + this.#ort.env.allowLocalModels = false; + this.#ort.env.wasm.numThreads = navigator.hardwareConcurrency; + this.#ort.env.wasm.wasmPaths = ONNX_BASE; + + const path = PATH_MAP[this.voiceId]; + const modelConfigBlob = await getBlob(`${HF_BASE}/${path}.json`); + this.#modelConfig = JSON.parse(await modelConfigBlob.text()); + + const modelBlob = await getBlob( + `${HF_BASE}/${path}`, + this.#progressCallback + ); + this.#ortSession = await this.#ort.InferenceSession.create( + await modelBlob.arrayBuffer() + ); + } + + async predict(text: string): Promise { + await this.waitReady; // wait for the session to be ready + + const input = JSON.stringify([{ text: text.trim() }]); + + const phonemeIds: string[] = await new Promise(async (resolve) => { + const module = await this.#createPiperPhonemize!({ + print: (data: any) => { + resolve(JSON.parse(data).phoneme_ids); + }, + printErr: (message: any) => { + throw new Error(message); + }, + locateFile: (url: string) => { + if (url.endsWith(".wasm")) return `${WASM_BASE}.wasm`; + if (url.endsWith(".data")) return `${WASM_BASE}.data`; + return url; + }, + }); - return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], { type: 
"audio/x-wav" }); + module.callMain([ + "-l", + this.#modelConfig.espeak.voice, + "--input", + input, + "--espeak_data", + "/espeak-ng-data", + ]); + }); + + const speakerId = 0; + const sampleRate = this.#modelConfig.audio.sample_rate; + const noiseScale = this.#modelConfig.inference.noise_scale; + const lengthScale = this.#modelConfig.inference.length_scale; + const noiseW = this.#modelConfig.inference.noise_w; + + const session = this.#ortSession!; + const feeds = { + input: new this.#ort!.Tensor("int64", phonemeIds, [1, phonemeIds.length]), + input_lengths: new this.#ort!.Tensor("int64", [phonemeIds.length]), + scales: new this.#ort!.Tensor("float32", [ + noiseScale, + lengthScale, + noiseW, + ]), + }; + if (Object.keys(this.#modelConfig.speaker_id_map).length) { + Object.assign(feeds, { + sid: new this.#ort!.Tensor("int64", [speakerId]), + }); + } + + const { + output: { data: pcm }, + } = await session.run(feeds); + + return new Blob([pcm2wav(pcm as Float32Array, 1, sampleRate)], { + type: "audio/x-wav", + }); + } +} + +/** + * Run text to speech inference in new worker thread. Fetches the model + * first, if it has not been saved to opfs yet. 
+ */ +export async function predict( + config: InferenceConfg, + callback?: ProgressCallback +): Promise { + const session = new TtsSession({ + voiceId: config.voiceId, + progress: callback, + }); + return session.predict(config.text); } /** From 9b38ee2a2f989f2f839845a1c545f1dd5ccd82c9 Mon Sep 17 00:00:00 2001 From: Sam Willis Date: Wed, 10 Jul 2024 19:58:33 +0200 Subject: [PATCH 2/2] Bump tsconfig lib to ES2022 so that build works --- tsconfig.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsconfig.json b/tsconfig.json index 2057dee..1058710 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -3,7 +3,7 @@ "target": "ES2020", "useDefineForClassFields": true, "module": "ESNext", - "lib": ["ES2020", "DOM", "DOM.Iterable","WebWorker"], + "lib": ["ES2022", "DOM", "DOM.Iterable","WebWorker"], "skipLibCheck": true, "allowJs": true,