feat: 🔥 use @diffusionstudio/vits-web for better TTS
avarayr committed Sep 20, 2024
1 parent 2e73eff commit 6c2b8a7
Showing 8 changed files with 341 additions and 15 deletions.
Binary file modified bun.lockb
8 changes: 7 additions & 1 deletion package.json
@@ -17,6 +17,7 @@
   },
   "dependencies": {
     "@ctrl/golang-template": "^1.4.1",
+    "@diffusionstudio/vits-web": "^1.0.3",
     "@hattip/core": "^0.0.48",
     "@hattip/router": "^0.0.48",
     "@hattip/vite": "^0.0.48",
@@ -115,5 +116,10 @@
     "workbox-core": "^7.1.0",
     "workbox-precaching": "^7.1.0",
     "workbox-window": "^7.1.0"
-  }
+  },
+  "trustedDependencies": [
+    "core-js",
+    "core-js-pure",
+    "protobufjs"
+  ]
 }
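Note: trustedDependencies is a Bun-specific package.json field. Bun does not run lifecycle (postinstall) scripts for packages that are not explicitly trusted, so core-js, core-js-pure, and protobufjs, presumably pulled in through the new vits-web dependency chain, are whitelisted here so their install scripts can run.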
118 changes: 105 additions & 13 deletions src/hooks/useSpeechSynthesis.tsx
@@ -1,19 +1,111 @@
-import { useCallback, useState } from "react";
+import { useCallback, useState, useEffect } from "react";
+import Worker from "../workers/ttsWorker?worker";
+import { VoiceId } from "@diffusionstudio/vits-web";
 
-export const useSpeechSynthesis = () => {
+type WorkerMessage =
+  | { type: "loadingProgress"; progress: number }
+  | { type: "loadingComplete" }
+  | { type: "availableVoices"; voices: VoiceId[] }
+  | { type: "result"; audio: Blob }
+  | { type: "error"; message: string };
+
+export const useSpeechSynthesis = ({
+  enabled = true,
+  selectedVoice,
+}: {
+  enabled?: boolean;
+  selectedVoice: VoiceId;
+}) => {
   const [isSpeaking, setIsSpeaking] = useState(false);
+  const [isLoading, setIsLoading] = useState(true);
+  const [loadingProgress, setLoadingProgress] = useState(0);
+  const [worker, setWorker] = useState<Worker | null>(null);
+  const [availableVoices, setAvailableVoices] = useState<VoiceId[]>([]);
 
+  useEffect(() => {
+    const ttsWorker = new Worker();
+
+    ttsWorker.onmessage = (event: MessageEvent<WorkerMessage>) => {
+      switch (event.data.type) {
+        case "loadingProgress":
+          setLoadingProgress(event.data.progress);
+          break;
+        case "loadingComplete":
+          setIsLoading(false);
+          break;
+        case "availableVoices":
+          setAvailableVoices(event.data.voices);
+          break;
+        case "error":
+          console.error("Error in TTS worker:", event.data.message);
+          break;
+      }
+    };
+    setWorker(ttsWorker);
+
-  const speak = useCallback((text: string): Promise<void> => {
-    return new Promise((resolve) => {
-      const utterance = new SpeechSynthesisUtterance(text);
-      utterance.onstart = () => setIsSpeaking(true);
-      utterance.onend = () => {
-        setIsSpeaking(false);
-        resolve();
-      };
-      window.speechSynthesis.speak(utterance);
-    });
+    ttsWorker.postMessage({ type: "getVoices" });
+
+    return () => {
+      ttsWorker.terminate();
+      setWorker(null);
+    };
   }, []);
 
-  return { speak, isSpeaking };
+  useEffect(() => {
+    if (worker && selectedVoice && enabled) {
+      setIsLoading(true);
+      worker.postMessage({ type: "init", voiceId: selectedVoice });
+    }
+
+    return () => {
+      if (worker) {
+        worker.postMessage({ type: "terminate" });
+      }
+    };
+  }, [worker, selectedVoice, enabled]);
+
+  const speak = useCallback(
+    async (text: string): Promise<void> => {
+      if (!worker) return;
+
+      const emojiRegex = /(\p{Emoji_Presentation}|\p{Emoji}\uFE0F)/gu;
+      const cleanedText = text.replace(emojiRegex, "");
+
+      return new Promise((resolve, reject) => {
+        let audio: HTMLAudioElement | null = null;
+
+        const messageHandler = (event: MessageEvent<WorkerMessage>) => {
+          switch (event.data.type) {
+            case "result":
+              audio = new Audio(URL.createObjectURL(event.data.audio));
+              audio.onended = () => {
+                setIsSpeaking(false);
+                cleanup();
+                resolve();
+              };
+              setIsSpeaking(true);
+              void audio.play();
+              break;
+            case "error":
+              cleanup();
+              reject(new Error(event.data.message));
+              break;
+          }
+        };
+
+        const cleanup = () => {
+          worker.removeEventListener("message", messageHandler);
+          if (audio) {
+            audio.onended = null;
+          }
+        };
+
+        worker.addEventListener("message", messageHandler);
+        worker.postMessage({ type: "speak", text: cleanedText });
+      });
+    },
+    [worker],
+  );
+
+  return { speak, isSpeaking, isLoading, loadingProgress, availableVoices, selectedVoice };
 };
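The commit also touches src/workers/ttsWorker.ts (one of the 8 changed files), but that diff is not expanded on this page. Going only by the message protocol the hook above relies on (getVoices / init / speak / terminate in; availableVoices / loadingProgress / loadingComplete / result / error out), a minimal sketch of such a worker could look like the following. The vits-web calls used here, voices(), download() with a progress callback, and predict() returning a Blob, are assumptions about the library's API, not the committed code.

// src/workers/ttsWorker.ts -- hypothetical sketch, NOT the committed file.
// Assumes @diffusionstudio/vits-web exposes voices(), download(), and predict();
// check the library's documentation for the real signatures.
import * as tts from "@diffusionstudio/vits-web";
import type { VoiceId } from "@diffusionstudio/vits-web";

let currentVoice: VoiceId | undefined;

self.onmessage = async (event: MessageEvent) => {
  const data = event.data;
  try {
    switch (data.type) {
      case "getVoices": {
        // Report the voices the library knows about.
        const voices = await tts.voices();
        self.postMessage({ type: "availableVoices", voices });
        break;
      }
      case "init": {
        // Fetch (or load from cache) the selected voice model, forwarding
        // download progress so the UI can render a percentage.
        currentVoice = data.voiceId as VoiceId;
        await tts.download(currentVoice, (p: { loaded: number; total: number }) => {
          self.postMessage({
            type: "loadingProgress",
            progress: Math.round((p.loaded / p.total) * 100),
          });
        });
        self.postMessage({ type: "loadingComplete" });
        break;
      }
      case "speak": {
        if (!currentVoice) throw new Error("Voice not initialized");
        // predict() is assumed to return a WAV Blob, which the hook wraps in an Audio element.
        const audio = await tts.predict({ text: data.text, voiceId: currentVoice });
        self.postMessage({ type: "result", audio });
        break;
      }
      case "terminate":
        self.close();
        break;
    }
  } catch (error) {
    self.postMessage({ type: "error", message: (error as Error).message });
  }
};

Whatever the exact implementation, running synthesis in a worker keeps model download and ONNX inference off the main thread, which is the point of swapping window.speechSynthesis for vits-web here.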
91 changes: 91 additions & 0 deletions src/hooks/useVoiceFormatting.ts
@@ -0,0 +1,91 @@
import { useMemo } from "react";
import { CountryCode, VoiceId } from "@diffusionstudio/vits-web";

type FormattedVoice = {
id: VoiceId;
flag: string;
name: string;
quality: {
text: string;
color: string;
};
};

const getFlag = (locale: string): string => {
const flagEmojis: Record<CountryCode, string> = {
ar_JO: "🇯🇴",
ca_ES: "🇪🇸",
cs_CZ: "🇨🇿",
da_DK: "🇩🇰",
de_DE: "🇩🇪",
el_GR: "🇬🇷",
en_GB: "🇬🇧",
en_US: "🇺🇸",
es_ES: "🇪🇸",
es_MX: "🇲🇽",
fa_IR: "🇮🇷",
fi_FI: "🇫🇮",
fr_FR: "🇫🇷",
hu_HU: "🇭🇺",
is_IS: "🇮🇸",
it_IT: "🇮🇹",
ka_GE: "🇬🇪",
kk_KZ: "🇰🇿",
lb_LU: "🇱🇺",
ne_NP: "🇳🇵",
nl_BE: "🇧🇪",
nl_NL: "🇳🇱",
no_NO: "🇳🇴",
pl_PL: "🇵🇱",
pt_BR: "🇧🇷",
pt_PT: "🇵🇹",
ro_RO: "🇷🇴",
ru_RU: "🇷🇺",
sk_SK: "🇸🇰",
sl_SI: "🇸🇮",
sr_RS: "🇷🇸",
sv_SE: "🇸🇪",
sw_CD: "🇨🇩",
tr_TR: "🇹🇷",
uk_UA: "🇺🇦",
vi_VN: "🇻🇳",
zh_CN: "🇨🇳",
};
return flagEmojis[locale as CountryCode] || "🏳️";
};

const getQualityColor = (quality: string): string => {
switch (quality.toLowerCase()) {
case "x_low":
return "text-red-800";
case "low":
return "text-red-500";
case "medium":
return "text-orange-500";
case "high":
return "text-green-500";
default:
return "text-gray-500";
}
};

export const useVoiceFormatting = (voices: VoiceId[]): FormattedVoice[] => {
const formattedVoices = useMemo(() => {
return voices
.filter((voice) => voice.startsWith("en_"))
.map((voice) => {
const [locale, name, quality] = voice.split("-");
return {
id: voice,
flag: getFlag(locale || "en_US"),
name: (name?.charAt(0)?.toUpperCase() ?? "") + (name?.slice(1) ?? "") || "Unknown",
quality: {
text: (quality?.charAt(0)?.toUpperCase() ?? "") + (quality?.slice(1) ?? "") || "Unknown",
color: getQualityColor(quality || "Unknown"),
},
};
});
}, [voices]);

return formattedVoices;
};
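For illustration (not part of the commit), the default voice ID that VideoCallModal uses below would come out of this hook as follows; voice IDs follow a locale-name-quality pattern that the split("-") above relies on.

// Hypothetical usage inside a component: "en_US-hfc_female-medium"
// splits into locale "en_US", name "hfc_female", quality "medium".
const [voice] = useVoiceFormatting(["en_US-hfc_female-medium"]);
// voice === {
//   id: "en_US-hfc_female-medium",
//   flag: "🇺🇸",                                          // getFlag("en_US")
//   name: "Hfc_female",                                    // first letter upper-cased
//   quality: { text: "Medium", color: "text-orange-500" }, // getQualityColor("medium")
// }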
76 changes: 75 additions & 1 deletion src/layouts/ChaiMessage/components/VideoCallModal.tsx
@@ -6,6 +6,20 @@ import { useMessageGeneration } from "~/hooks/useMessageGeneration";
 import { useSpeechRecognition } from "~/hooks/useSpeechRecognition";
 import { useSpeechSynthesis } from "~/hooks/useSpeechSynthesis";
 import { api } from "~/trpc/react";
+import {
+  Select,
+  SelectItem,
+  SelectTrigger,
+  SelectContent,
+  SelectValue,
+  SelectGroup,
+  SelectLabel,
+  SelectSeparator,
+} from "~/components/primitives/Select";
+import { VoiceId } from "@diffusionstudio/vits-web";
+import { useVoiceFormatting } from "~/hooks/useVoiceFormatting";
+import { useLocalStorage } from "usehooks-ts";
+import { ClientConsts } from "~/utils/client-consts";
 
 type VideoCallModalProps = {
   isOpen: boolean;
@@ -29,6 +43,8 @@ function getBraveDetected() {
 }
 
 export const VideoCallModal = ({ isOpen, onClose, chatId }: VideoCallModalProps) => {
+  "use no memo";
+
   const { tryFollowMessageGeneration } = useMessageGeneration(chatId);
 
   const [isSpeaking, setIsSpeaking] = useState(false);
@@ -46,7 +62,18 @@ export const VideoCallModal = ({ isOpen, onClose, chatId }: VideoCallModalProps)
     cleanup: cleanupSpeechRecognition,
   } = useSpeechRecognition();
 
-  const { speak, isSpeaking: isTTSSpeaking } = useSpeechSynthesis();
+  const [selectedVoice, setSelectedVoice] = useLocalStorage<VoiceId>(
+    ClientConsts.LocalStorageKeys.selectedVoice,
+    "en_US-hfc_female-medium",
+  );
+
+  const {
+    speak,
+    isSpeaking: isTTSSpeaking,
+    isLoading: isTTSLoading,
+    loadingProgress,
+    availableVoices,
+  } = useSpeechSynthesis({ enabled: isOpen, selectedVoice });
   const audioContextRef = useRef<AudioContext | null>(null);
   const analyserRef = useRef<AnalyserNode | null>(null);
   const dataArrayRef = useRef<Uint8Array | null>(null);
@@ -299,6 +326,15 @@ export const VideoCallModal = ({ isOpen, onClose, chatId }: VideoCallModalProps)
     };
   }, []);
 
+  const formattedVoices = useVoiceFormatting(availableVoices);
+
+  const handleVoiceChange = useCallback(
+    (value: string) => {
+      setSelectedVoice(value as VoiceId);
+    },
+    [setSelectedVoice],
+  );
+
   return (
     <AnimatePresence>
       {isOpen && (
@@ -329,6 +365,44 @@ export const VideoCallModal = ({ isOpen, onClose, chatId }: VideoCallModalProps)
           </motion.div>
         )}
 
+        {/* Voice selection */}
+        <Select value={selectedVoice} onValueChange={handleVoiceChange} disabled={isTTSLoading || isTTSSpeaking}>
+          <SelectTrigger className="w-full">
+            <SelectValue placeholder="Select a voice" />
+          </SelectTrigger>
+          <SelectContent>
+            {formattedVoices.map((voice) => (
+              <SelectItem key={voice.id} value={voice.id}>
+                <div className="flex items-center space-x-2">
+                  <span className="text-xl">{voice.flag}</span>
+                  <span>{voice.name}</span>
+                  <span className={`ml-auto ${voice.quality.color}`}>{voice.quality.text}</span>
+                </div>
+              </SelectItem>
+            ))}
+          </SelectContent>
+        </Select>
+
+        {/* TTS model loading progress */}
+        <AnimatePresence mode="popLayout">
+          {isTTSLoading && (
+            <motion.div
+              className="mb-4 mt-2 w-full rounded-lg bg-blue-900 p-4 text-sm text-blue-100"
+              initial={{ opacity: 0 }}
+              animate={{ opacity: 1 }}
+              exit={{ opacity: 0 }}
+            >
+              <div className="flex items-center justify-between">
+                <span>Loading TTS model...</span>
+                <span>{loadingProgress}%</span>
+              </div>
+              <div className="mt-2 h-2 w-full rounded-full bg-blue-200">
+                <div className="h-full rounded-full bg-blue-500" style={{ width: `${loadingProgress}%` }}></div>
+              </div>
+            </motion.div>
+          )}
+        </AnimatePresence>
+
         {/* Main circle with outer progress ring */}
         <div className="relative flex h-60 w-60 items-center justify-center sm:h-72 sm:w-72 md:h-80 md:w-80">
           {/* Progress ring */}
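Two details worth noting in this file: the new "use no memo" directive opts the component out of React Compiler's automatic memoization, presumably because the imperative audio and AudioContext handling here does not play well with compiled memoization, and the selected voice persists across sessions through useLocalStorage under the new selectedVoice key added to ClientConsts below.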
1 change: 1 addition & 0 deletions src/utils/client-consts.ts
@@ -3,6 +3,7 @@ export const ClientConsts = {
     areNotificationsEnabled: "areNotificationsEnabled",
     dbSubscriptionID: "dbSubscriptionID",
     areSilentNotifications: "areSilentNotifications",
+    selectedVoice: "selectedVoice",
   },
   /**
    * TODO: Reduce this number once we figure out how to handle:
