Skip to content

Commit

Permalink
tts/stt POC. Still needs cleanup, but this works.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikejgray committed Dec 18, 2023
1 parent 658e1b9 commit b050ba0
Show file tree
Hide file tree
Showing 10 changed files with 447 additions and 112 deletions.
2 changes: 1 addition & 1 deletion neon_iris/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _send_utterance(self, utterance: str, lang: str,
self._send_serialized_message(serialized)

def _send_audio(self, audio_file: str, lang: str,
username: str, user_profiles: list,
username: Optional[str], user_profiles: Optional[list],
context: Optional[dict] = None):
context = context or dict()
audio_data = encode_file_to_base64_string(audio_file)
Expand Down
1 change: 1 addition & 0 deletions neon_iris/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .web_sat import UserInput, UserInputResponse # noqa
16 changes: 16 additions & 0 deletions neon_iris/models/web_sat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""API data models for the WebSAT API."""
from typing import Optional
from pydantic import BaseModel

class UserInput(BaseModel):
    """UserInput is the input data model for the WebSAT API."""
    # Text utterance typed by the user; empty when audio input is used instead.
    utterance: Optional[str] = ""
    # Base64-encoded audio recording for server-side STT; empty for text input.
    # (The web client sends the output of wavBlobToBase64 here.)
    audio_input: Optional[str] = ""
    # Client session identifier; defaults to the shared "websat0000" session.
    # NOTE(review): fields are typed Optional yet default to "" — confirm
    # whether None is actually a meaningful value here.
    session_id: str = "websat0000"

class UserInputResponse(BaseModel):
    """UserInputResponse is the response data model for the WebSAT API."""
    # Echo/STT transcription of the user's input; the client renders this as
    # the user's chat bubble when the input was audio.
    utterance: Optional[str] = ""
    # Base64-encoded WAV audio of the spoken reply; decoded client-side and
    # played back as TTS.
    audio_output: Optional[str] = ""
    # Client session identifier; defaults to the shared "websat0000" session.
    session_id: str = "websat0000"
    # The assistant's textual reply shown in the chat history (required).
    transcription: str
30 changes: 0 additions & 30 deletions neon_iris/static/index.html

This file was deleted.

148 changes: 148 additions & 0 deletions neon_iris/static/scripts/ui.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Read the chat input box, render/persist the user's message, clear the
// box, and kick off the request for an AI reply. No-op on blank input.
function submitMessage() {
  const input = document.getElementById("chatInput");
  const text = input.value.trim();
  if (text === "") return;

  appendMessageToHistory(createMessageDiv("user", text));
  saveMessageToLocalStorage("user", text);
  input.value = "";

  // Fetch the AI response for this utterance and update the history.
  getAIResponse(text);
}

// POST the user's input (typed text, or a base64 WAV recording for STT) to
// /user_input, render the transcription and reply in the chat history,
// persist both, and play the returned TTS audio. Errors are logged only.
async function getAIResponse(text = "", recording = "") {
  try {
    // Prefer the typed utterance; otherwise send the audio for STT.
    const payload =
      text !== "" && recording === ""
        ? { utterance: text }
        : { audio_input: recording };

    const response = await fetch("/user_input", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify(payload),
    });

    if (!response.ok) {
      throw new Error("Network response was not ok: " + response.statusText);
    }

    const data = await response.json();
    // BUGFIX: was console.debug(data, null, 4) — the (null, 4) arguments
    // belong to JSON.stringify, not console.debug.
    console.debug(JSON.stringify(data, null, 4));

    // 'transcription' carries the assistant's textual reply.
    const aiMessage = data.transcription;

    // When the input was audio, also render the server's STT transcription
    // of what the user said.
    if (text === "" && recording !== "") {
      const userMessageDiv = createMessageDiv("user", data.utterance);
      appendMessageToHistory(userMessageDiv);
      saveMessageToLocalStorage("user", data.utterance);
    }

    const aiMessageDiv = createMessageDiv("ai", aiMessage);
    appendMessageToHistory(aiMessageDiv);
    saveMessageToLocalStorage("ai", aiMessage);

    // Play the TTS audio returned by the server.
    const audioBlob = base64ToBlob(data.audio_output, "audio/wav");
    const audioUrl = URL.createObjectURL(audioBlob);
    const audio = new Audio(audioUrl);
    audio.type = "audio/wav";
    // Attach the handler before starting playback so a very short clip
    // cannot finish before the handler exists.
    audio.onended = () => {
      // BUGFIX: the original called myVad.pause() unconditionally in the
      // else branch, throwing a TypeError when the VAD was not initialized.
      if (shouldListen && myVad) {
        myVad.start();
      } else if (myVad) {
        myVad.pause();
      }
      // BUGFIX: release the object URL to avoid leaking one per reply.
      URL.revokeObjectURL(audioUrl);
    };
    await audio.play();
  } catch (error) {
    console.error("Error fetching AI response:", error);
    // Handle the error, such as showing a message to the user
  }
}

// Demo helper: after a one-second delay, render and persist a canned AI
// reply as if it came from the server.
function simulateAIResponse() {
  const delayMs = 1000; // simulated network/processing latency
  setTimeout(() => {
    const reply = "This is a sample AI response.";
    appendMessageToHistory(createMessageDiv("ai", reply));
    saveMessageToLocalStorage("ai", reply);
  }, delayMs);
}

// Build a chat-bubble <div> for the given sender ("user" or "ai"); the
// class name (e.g. "user-message") selects the CSS styling. Uses
// textContent so the message is never interpreted as HTML.
function createMessageDiv(sender, message) {
  const bubble = document.createElement("div");
  bubble.className = sender + "-message";
  bubble.textContent = message;
  return bubble;
}

// Append a message bubble to the chat history pane, then scroll to the
// bottom on the next tick so layout has happened before we measure.
function appendMessageToHistory(messageDiv) {
  const history = document.getElementById("chatHistory");
  history.appendChild(messageDiv);
  setTimeout(() => {
    history.scrollTop = history.scrollHeight;
  }, 0);
}

// Persist one chat message by appending {sender, message} to the
// "chatHistory" array kept in localStorage.
function saveMessageToLocalStorage(sender, message) {
  const history = JSON.parse(localStorage.getItem("chatHistory")) || [];
  history.push({ sender, message });
  localStorage.setItem("chatHistory", JSON.stringify(history));
}

// Decode a base64 string (optionally a full data: URL — any MIME type, not
// just audio/wav as before) into a Blob with the given MIME type.
// @param base64 base64 payload, with or without a "data:<mime>;base64," prefix
// @param mimeType MIME type to stamp on the resulting Blob
// @returns Blob containing the decoded bytes
function base64ToBlob(base64, mimeType) {
  // Generalized: strip any data-URL prefix, not only data:audio/wav.
  const payload = base64.replace(/^data:[^;,]+;base64,/, "");
  const byteCharacters = atob(payload);
  // Fill the typed array directly instead of via an intermediate Array.
  const bytes = new Uint8Array(byteCharacters.length);
  for (let i = 0; i < byteCharacters.length; i++) {
    bytes[i] = byteCharacters.charCodeAt(i);
  }
  return new Blob([bytes], { type: mimeType });
}

// On page load, rebuild the chat pane from the history persisted in
// localStorage (empty array when nothing has been saved yet).
window.addEventListener("load", () => {
  const saved = JSON.parse(localStorage.getItem("chatHistory")) || [];
  saved.forEach(({ sender, message }) => {
    appendMessageToHistory(createMessageDiv(sender, message));
  });
});

document.addEventListener("DOMContentLoaded", function () {
  const chatInput = document.getElementById("chatInput");

  // Submit on Enter or Ctrl+Enter; Shift+Enter falls through so the
  // browser can insert a newline.
  chatInput.addEventListener("keydown", function (event) {
    if (event.key !== "Enter") return;
    if (event.ctrlKey || !event.shiftKey) {
      event.preventDefault();
      submitMessage();
    }
  });
});
33 changes: 21 additions & 12 deletions neon_iris/static/scripts/websocket.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ function float32To16BitPCM(output, offset, input) {
}
}

// Resolve with the base64 payload of the given Blob. FileReader yields a
// "data:<mime>;base64,<payload>" URL; everything after the first comma is
// the payload. Rejects on read errors.
function wavBlobToBase64(blob) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onloadend = () => {
      const dataUrl = reader.result;
      resolve(dataUrl.split(",")[1]);
    };
    reader.onerror = reject;
    reader.readAsDataURL(blob);
  });
}

let shouldListen = false; // Global state flag for controlling VAD listening state
let myVad; // VAD instance
let isVadRunning = false; // True while the VAD is actively capturing audio — TODO confirm against initializeVad
Expand All @@ -53,11 +69,12 @@ async function initializeVad() {
}
}

function handleSpeechEnd(audio) {
async function handleSpeechEnd(audio) {
const wavBlob = float32ArrayToWavBlob(audio);
const audioUrl = URL.createObjectURL(wavBlob);
const audioOutput = await wavBlobToBase64(wavBlob);

// Save the audio as a file
// Save the spoken audio as a downloadable file
const downloadArea = document.getElementById("download-area");
if (downloadArea) {
downloadArea.innerHTML = "";
Expand All @@ -75,16 +92,8 @@ function handleSpeechEnd(audio) {
shouldListen = false;
}

// Play back the audio
const playbackAudio = new Audio(audioUrl);
playbackAudio.play();
playbackAudio.onended = () => {
if (shouldListen && myVad) {
myVad.start();
} else {
myVad.pause();
}
};
// Send audio to STT
getAIResponse("", audioOutput);
}

function toggleListeningState() {
Expand Down
51 changes: 44 additions & 7 deletions neon_iris/static/styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ a:hover {
cursor: pointer;
outline: none;
margin-bottom: 10px;
margin-top: 10px;
transition: background-color 0.3s;
background-color: #4caf50;
max-width: 50%;
align-content: center;
align-self: center;
}
#startButton.listening {
background-color: #03a9f4;
Expand All @@ -51,12 +55,6 @@ a:hover {
background-color: #333;
margin: 0;
}
.chat-container {
display: flex;
flex-direction: column;
height: 100%;
background-color: #1a1a1a;
}
.chat-header {
padding: 20px;
background-color: #333;
Expand All @@ -72,10 +70,19 @@ a:hover {
); /* Slight increase in size on hover for dynamic effect */
}
.chat-window {
flex: 1;
display: flex;
flex-direction: column;
padding: 20px;
overflow: auto;
}
#chatHistory {
display: flex;
flex-direction: column;
align-items: flex-start; /* Align items to the start by default */
height: 100%;
overflow-y: auto; /* Allows scrolling if content overflows */
background-color: #1a1a1a; /* Dark background for the chat container */
}
.input-area {
display: flex;
padding: 20px;
Expand Down Expand Up @@ -109,6 +116,36 @@ a:hover {
text-align: center;
font-size: 1.5em;
}
/* Style for user messages */
.user-message {
background-color: #007bff; /* Blue background for user messages */
color: #fff; /* White text color for user messages */
padding: 5px 10px;
margin: 5px 0;
border-radius: 10px;
align-self: flex-end; /* Right-align user messages */
max-width: 60%;
word-wrap: break-word; /* Wrap long words if needed */
}

/* Style for AI messages */
.ai-message {
background-color: #e0e0e0; /* Gray background for AI messages */
color: #000; /* Black text color for AI messages */
padding: 5px 10px;
margin: 5px 0;
border-radius: 10px;
align-self: flex-start; /* Left-align AI messages */
max-width: 60%;
word-wrap: break-word; /* Wrap long words if needed */
}

.chat-container {
display: flex;
flex-direction: column;
height: 100%;
background-color: #1a1a1a;
}

/* Responsive design adjustments */
@media (max-width: 768px) {
Expand Down
5 changes: 4 additions & 1 deletion neon_iris/static/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
</div>
<button id="startButton">Start Recording</button>
<div class="chat-window" id="chatWindow">
<!-- Chat messages will appear here -->
<div id="chatHistory">
<!-- Chat messages will appear here -->
</div>
</div>
<div id="download-area"></div>
<div class="input-area">
Expand All @@ -29,4 +31,5 @@
<!-- AI Code -->
<script src="/static/scripts/websocket.js"></script>
<script src="/static/scripts/audio.js"></script>
<script src="/static/scripts/ui.js"></script>
</html>
4 changes: 2 additions & 2 deletions neon_iris/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str,
def on_user_input(self, utterance: str,
chat_history: List[Tuple[str, str]],
audio_input: str,
client_session: str) -> (List[Tuple[str, str]], str, str, None, str):
client_session: str):
"""
Callback to handle textual user input
@param utterance: String utterance submitted by the user
@returns: Input box contents, Updated chat history, Gradio session ID, audio input, audio output
"""
input_time = time()
LOG.debug(f"Input received")
LOG.debug("Input received")
if not self._await_response.wait(30):
LOG.error("Previous response not completed after 30 seconds")
in_queue = time() - input_time
Expand Down
Loading

0 comments on commit b050ba0

Please sign in to comment.