Skip to content

Commit

Permalink
Merge pull request #797 from kunci115/main
Browse files Browse the repository at this point in the history
live streaming socket
  • Loading branch information
manmay-nakhashi authored Jun 27, 2024
2 parents e802ac5 + 0628710 commit 86100b6
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 2 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ This script allows you to speak a single phrase with one or more voices.
```shell
python tortoise/do_tts.py --text "I'm going to speak this" --voice random --preset fast
```
### Socket streaming
```shell
python tortoise/socket_server.py
```
The server will listen on port 5000.


### faster inference read.py

This script provides tools for reading large amounts of text.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ py-cpuinfo
hjson
psutil
sounddevice
spacy==3.7.5
50 changes: 50 additions & 0 deletions tortoise/socket_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import socket
import sounddevice as sd
import numpy as np

def _write_samples(stream, raw):
    """Write the whole float32 samples contained in *raw* to *stream*.

    Trailing bytes that do not form a complete 4-byte sample are dropped,
    because ``np.frombuffer`` raises ``ValueError`` on a partial element.
    """
    usable = len(raw) - len(raw) % 4
    if usable:
        stream.write(np.frombuffer(bytes(raw[:usable]), dtype=np.float32))


def play_audio_stream(client_socket):
    """Play raw float32 audio received from *client_socket* until it ends.

    The stream is played as 24 kHz mono float32. Playback stops when the
    ``END_OF_AUDIO`` marker is seen or when the peer closes the connection.

    Args:
        client_socket: A connected socket object exposing ``recv``.

    NOTE(review): the marker is sent in-band, so audio bytes that happen to
    equal ``END_OF_AUDIO`` would end playback early; a length-prefixed
    protocol would remove that ambiguity.
    """
    end_marker = b"END_OF_AUDIO"
    block_size = 4096
    # Bytes kept back from playback: they could be the start of a marker
    # whose remainder arrives with the next recv().
    holdback = len(end_marker) - 1

    buffer = bytearray()
    stream = sd.OutputStream(samplerate=24000, channels=1, dtype='float32')
    stream.start()

    try:
        while True:
            chunk = client_socket.recv(1024)
            if not chunk:
                # Peer closed the connection without sending the marker;
                # play what is left instead of spinning on empty reads.
                _write_samples(stream, buffer)
                break

            buffer += chunk

            # Search the accumulated buffer, not just the latest chunk, so a
            # marker split across two recv() calls is still detected.
            marker_at = buffer.find(end_marker)
            if marker_at != -1:
                _write_samples(stream, buffer[:marker_at])
                break

            while len(buffer) - holdback >= block_size:
                _write_samples(stream, buffer[:block_size])
                del buffer[:block_size]

    finally:
        stream.stop()
        stream.close()

def send_text_to_server(character_name, text, server_ip='localhost', server_port=5000):
    """Send a synthesis request to the server and play the returned audio.

    The request is sent as the UTF-8 encoding of ``"<character_name>|<text>"``.

    Args:
        character_name: Voice name placed before the ``|`` separator.
        text: Text to be spoken.
        server_ip: Host name or address of the server.
        server_port: TCP port of the server.

    Raises:
        OSError: If the connection cannot be established or is lost.
    """
    # The context manager closes the socket on every exit path, including a
    # failed connect(), which previously happened outside the try/finally.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
        client_socket.connect((server_ip, server_port))

        request = f"{character_name}|{text}"
        client_socket.sendall(request.encode('utf-8'))

        play_audio_stream(client_socket)

        print("Audio playback finished.")


if __name__ == "__main__":
    # Demo request against the default server address (localhost:5000).
    send_text_to_server(
        character_name="deniro",
        text="Hello This is just for a live speaking test",
    )
83 changes: 83 additions & 0 deletions tortoise/socket_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import spacy
import threading
import socket
from tortoise.api_fast import TextToSpeech
from utils.audio import load_voices

# Single TextToSpeech instance, built at import time; start_server() hands
# this same object to every client handler thread.
tts = TextToSpeech()
# spaCy pipeline whose sentence boundaries split_text() uses for chunking.
nlp = spacy.load("en_core_web_sm")


def generate_audio_stream(text, tts, voice_samples):
    """Yield the audio chunks produced by ``tts.tts_stream`` for *text*.

    Args:
        text: Text to synthesize.
        tts: Object providing ``tts_stream``.
        voice_samples: Despite the name, this is the voice identifier that
            is passed (wrapped in a list) to ``load_voices``.

    Yields:
        Each chunk produced by ``tts.tts_stream``, unchanged.
    """
    print(f"Generating audio stream...: {text}")
    reference_clips, latents = load_voices([voice_samples])
    yield from tts.tts_stream(
        text,
        voice_samples=reference_clips,
        conditioning_latents=latents,
        verbose=True,
        stream_chunk_size=40,  # Adjust chunk size as needed
    )


def split_text(text, max_length=200):
    """Group the sentences of *text* into chunks of roughly *max_length* chars.

    Sentences come from the module-level ``nlp`` pipeline and are never split:
    a single sentence longer than *max_length* becomes a chunk of its own.

    Args:
        text: Text to split.
        max_length: Target upper bound on chunk length, counting one
            separator character per sentence.

    Returns:
        A list of chunks, each made of whole sentences joined with a space.
    """
    doc = nlp(text)
    chunks = []
    current = []
    current_length = 0

    for sent in doc.sents:
        sentence = sent.text
        # Flush only when something has been accumulated; otherwise an
        # over-long first sentence would emit an empty chunk ahead of itself.
        if current and current_length + len(sentence) > max_length:
            chunks.append(' '.join(current))
            current = []
            current_length = 0
        current.append(sentence)
        current_length += len(sentence) + 1  # +1 for the joining space

    if current:
        chunks.append(' '.join(current))

    return chunks


def handle_client(client_socket, tts):
    """Serve synthesis requests on one client connection until it closes.

    Each request is expected as UTF-8 ``"<voice>|<text>"``. The text is split
    into chunks, every chunk is synthesized, and the resulting audio bytes
    are sent back followed by ``END_OF_AUDIO``.

    Args:
        client_socket: Connected socket for this client; always closed on exit.
        tts: Text-to-speech object forwarded to ``generate_audio_stream``.

    NOTE(review): a request is read with a single ``recv(1024)``, so longer
    requests are truncated or split across iterations; the protocol has no
    framing to detect this.
    """
    try:
        while True:
            raw_request = client_socket.recv(1024)
            if not raw_request:
                break

            try:
                # UnicodeDecodeError is a ValueError subclass, so one handler
                # covers both invalid UTF-8 and a missing '|' separator.
                character_name, text = raw_request.decode('utf-8').split('|', 1)
            except ValueError:
                print("Malformed request; expected '<voice>|<text>'.")
                # Still terminate the reply so the client stops waiting.
                client_socket.sendall(b"END_OF_AUDIO")
                continue

            text_chunks = split_text(text, max_length=200)
            print(text_chunks)
            for chunk in text_chunks:
                if not chunk.strip():
                    # Nothing to speak; do not hand blank text to the engine.
                    continue
                audio_stream = generate_audio_stream(chunk, tts, character_name)

                for audio_chunk in audio_stream:
                    audio_data = audio_chunk.cpu().numpy().flatten()
                    client_socket.sendall(audio_data.tobytes())

            client_socket.sendall(b"END_OF_AUDIO")

    finally:
        client_socket.close()
        print("Client disconnected.")


def start_server(host='0.0.0.0', port=5000):
    """Accept TCP clients forever, handling each one in its own thread.

    Args:
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to listen on.

    Raises:
        OSError: If the address cannot be bound.

    NOTE(review): the service has no authentication; bind to ``127.0.0.1``
    when it should not be reachable from other machines.
    """
    # The context manager releases the listening socket on any exit,
    # including KeyboardInterrupt.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow an immediate restart while old connections are in TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(5)
        print(f"Server listening on port {port}")

        while True:
            client_socket, addr = server.accept()
            print(f"Accepted connection from {addr}")
            client_handler = threading.Thread(target=handle_client, args=(client_socket, tts))
            client_handler.start()


# Start the blocking accept loop only when this file is run as a script.
if __name__ == "__main__":
    start_server()
14 changes: 12 additions & 2 deletions tortoise/utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,30 @@ def get_voices(extra_voice_dirs=[]):
voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
return voices

def save_pth(conds, save_path):
    """Serialize *conds* to *save_path* with ``torch.save``."""
    torch.save(obj=conds, f=save_path)


def load_voice(voice, extra_voice_dirs=[]):
    """Load the conditioning data for a single voice.

    Args:
        voice: Voice name, or ``'random'`` for no conditioning data.
        extra_voice_dirs: Forwarded unchanged to ``get_voices``.

    Returns:
        A ``(clips, latents)`` tuple:
        * ``(None, None)`` for ``'random'``;
        * ``(None, loaded_object)`` when the voice consists of exactly one
          ``.pth`` file;
        * ``(clips, None)`` otherwise, where ``clips`` holds one
          ``load_audio`` result per non-``.pth`` file.

    Raises:
        KeyError: Presumably, when *voice* is not among the known voices
            (raised by the ``voices[voice]`` lookup).
    """
    if voice == 'random':
        return None, None

    voices = get_voices(extra_voice_dirs)
    paths = voices[voice]
    pth_files = [p for p in paths if p.endswith('.pth')]
    if len(paths) == 1 and pth_files:
        # NOTE: torch.load unpickles the file; only load voice files that
        # come from a trusted source.
        return None, torch.load(paths[0])

    # 22050 is passed straight to load_audio (presumably the sample rate).
    conds = [load_audio(p, 22050) for p in paths if not p.endswith('.pth')]

    # Only cache when clips were actually loaded. A voice with no files
    # leaves `paths` empty, and indexing paths[0] for the cache location
    # would raise IndexError.
    if conds and not pth_files:
        # NOTE(review): this stores the raw clips, whereas the single-.pth
        # branch above returns the file content in the latents slot. If the
        # audio files are later removed, the clips would come back as
        # latents - confirm this is intended.
        pth_save_path = os.path.join(os.path.dirname(paths[0]), f"{voice}.pth")
        save_pth(conds, pth_save_path)

    return conds, None


Expand Down

0 comments on commit 86100b6

Please sign in to comment.