Skip to content

Commit

Permalink
Merge pull request #560 from TEN-framework/feature/bytedance-tts-flush
Browse files Browse the repository at this point in the history
chore: flush cmd and latency
  • Loading branch information
plutoless authored Jan 8, 2025
2 parents 3602652 + 3b69e0e commit c0f2956
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 39 deletions.
120 changes: 83 additions & 37 deletions agents/ten_packages/extension/bytedance_tts/bytedance_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import json
import gzip
import asyncio
import threading
from datetime import datetime


MESSAGE_TYPES = {
Expand All @@ -36,6 +38,8 @@
# Human-readable names for the serialization-method code parsed out of a
# response header; looked up when logging the decoded header.
MESSAGE_SERIALIZATION_METHODS = {
    0: "no serialization",
    1: "JSON",
    15: "custom type",
}
# Human-readable names for the compression code parsed out of a response
# header (1 means the payload is gzip-compressed).
MESSAGE_COMPRESSIONS = {
    0: "no compression",
    1: "gzip",
    15: "custom compression method",
}

# Throttle threshold consulted by `record_latency` to limit how often a
# latency sample is logged.
LATENCY_SAMPLE_INTERVAL_MS = 5


@dataclass
class TTSConfig(BaseConfig):
Expand Down Expand Up @@ -88,18 +92,31 @@ def __init__(self, config: TTSConfig, ten_env: AsyncTenEnv) -> None:
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
self.default_header = bytearray(b"\x11\x10\x11\x00")
self._cancel = threading.Event()

# Latency.
self._latest_record_time = None

def is_cancelled(self) -> bool:
    """Return True while the internal cancel event is set."""
    cancel_flag = self._cancel
    return cancel_flag.is_set()

async def cancel(self) -> None:
    """Request cancellation by setting the internal cancel event.

    This only raises the flag; it performs no I/O and does not wait for
    anything to observe it.
    """
    cancel_flag = self._cancel
    cancel_flag.set()

async def connect(self) -> None:
    """Open a websocket to ``self.config.api_url`` and keep it on ``self.websocket``.

    The configured token is sent in the ``Authorization`` header.
    """
    auth_headers = {"Authorization": f"Bearer; {self.config.token}"}
    connection_options = {
        "extra_headers": auth_headers,
        "ping_interval": None,
        # Fast close, as the `flush` cmd will close the connection.
        "close_timeout": 1,
    }
    self.websocket = await websockets.connect(
        self.config.api_url, **connection_options
    )
    self.ten_env.log_info("Websocket connection established.")

async def close(self) -> None:
    """Close the current websocket, if any, and forget it.

    When no websocket is held, this only logs that fact. Otherwise the
    connection is closed and ``self.websocket`` is reset to ``None``.
    """
    connection = self.websocket
    if connection is None:
        self.ten_env.log_info("Websocket is not connected.")
        return

    await connection.close()
    self.websocket = None
    self.ten_env.log_info("Websocket connection closed.")
Expand All @@ -118,39 +135,39 @@ def parse_response(self, response: websockets.Data) -> Tuple[bytes, bool]:
reserved = response[3]
header_extensions = response[4 : header_size * 4]
payload = response[header_size * 4 :]
self.ten_env.log_info(
self.ten_env.log_debug(
f"Protocol version: {protocol_version:#x} - version {protocol_version}"
)
self.ten_env.log_info(
self.ten_env.log_debug(
f"Header size: {header_size:#x} - {header_size * 4} bytes"
)
self.ten_env.log_info(
self.ten_env.log_debug(
f"Message type: {message_type:#x} - {MESSAGE_TYPES[message_type]}"
)
self.ten_env.log_info(
self.ten_env.log_debug(
f"Message type specific flags: {message_type_specific_flags:#x} - {MESSAGE_TYPE_SPECIFIC_FLAGS[message_type_specific_flags]}"
)
self.ten_env.log_info(
self.ten_env.log_debug(
f"Message serialization method: {serialization_method:#x} - {MESSAGE_SERIALIZATION_METHODS[serialization_method]}"
)
self.ten_env.log_info(
self.ten_env.log_debug(
f"Message compression: {message_compression:#x} - {MESSAGE_COMPRESSIONS[message_compression]}"
)
self.ten_env.log_info(f"Reserved: {reserved:#04x}")
self.ten_env.log_debug(f"Reserved: {reserved:#04x}")

if header_size != 1:
self.ten_env.log_info(f"Header extensions: {header_extensions}")
self.ten_env.log_debug(f"Header extensions: {header_extensions}")

if message_type == 0xB: # audio-only server response
if message_type_specific_flags == 0: # no sequence number as ACK
self.ten_env.log_info("Payload size: 0")
self.ten_env.log_debug("Payload size: 0")
return None, False
else:
sequence_number = int.from_bytes(payload[:4], "big", signed=True)
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
payload = payload[8:]
self.ten_env.log_info(f"Sequence number: {sequence_number}")
self.ten_env.log_info(f"Payload size: {payload_size} bytes")
self.ten_env.log_debug(f"Sequence number: {sequence_number}")
self.ten_env.log_debug(f"Payload size: {payload_size} bytes")
if sequence_number < 0:
return payload, True
else:
Expand All @@ -171,14 +188,34 @@ def parse_response(self, response: websockets.Data) -> Tuple[bytes, bool]:
payload = payload[4:]
if message_compression == 1:
payload = gzip.decompress(payload)
self.ten_env.log_info(f"Frontend message: {payload}")
self.ten_env.log_debug(f"Frontend message: {payload}")
else:
self.ten_env.log_error("undefined message type!")
return None, True

def record_latency(self, request_id: str, start: datetime) -> None:
    """Log the time elapsed since ``start`` for a request, rate-limited.

    A sample is logged only when at least ``LATENCY_SAMPLE_INTERVAL_MS``
    milliseconds have passed since the previously logged sample (or when
    nothing has been logged yet); otherwise the call is a no-op.

    Args:
        request_id: Identifier included in the log line.
        start: Moment the request started; the latency is measured from it.
    """
    # Take a single timestamp so the throttle check and the reported
    # latency refer to the same instant.
    now = datetime.now()

    last_logged = self._latest_record_time
    if last_logged is not None:
        # The threshold is expressed in milliseconds, so convert the elapsed
        # time before comparing (total_seconds() alone yields seconds).
        since_last_ms = (now - last_logged).total_seconds() * 1000
        if since_last_ms < LATENCY_SAMPLE_INTERVAL_MS:
            return

    self._latest_record_time = now
    latency = int((now - start).total_seconds() * 1000)
    self.ten_env.log_info(f"Request ({request_id}), ttfb {latency}ms.")

async def text_to_speech_stream(self, text: str) -> AsyncIterator[bytes]:
ws = self.websocket
if ws is None:
await self.connect()
ws = self.websocket

start_ms = datetime.now()
request_id = str(uuid.uuid4())

request = copy.deepcopy(self.request_template)
request["request"]["reqid"] = str(uuid.uuid4())
request["request"]["reqid"] = request_id
request["request"]["text"] = text
request["user"]["uid"] = str(uuid.uuid4())

Expand All @@ -192,27 +229,36 @@ async def text_to_speech_stream(self, text: str) -> AsyncIterator[bytes]:
# payload
full_request.extend(request_bytes)

if self.websocket is not None:
try:
await self.websocket.send(full_request)
self.ten_env.log_info(f"Sent: {request}")

while True:
resp = await self.websocket.recv()
payload, done = self.parse_response(resp)
if payload:
yield payload

if done:
self.ten_env.log_info(
f"Response is completed for request: {request['request']['reqid']}."
)
break

except websockets.exceptions.ConnectionClosedError as e:
self.ten_env.log_info(f"Connection closed with error: {e}.")
await self.reconnect()
except asyncio.TimeoutError:
self.ten_env.log_info("Timeout waiting for response.")
else:
self.ten_env.log_error("Websocket is not connected.")
try:
await ws.send(full_request)
self.ten_env.log_info(f"Sent: {request}")

while True:
if self.is_cancelled():
self.ten_env.log_info(f"Request ({request_id}) has been cancelled.")

# The current connection must be closed, because the server will not drop the remaining data.
await self.close()
self._cancel.clear()
break

resp = await ws.recv()
payload, done = self.parse_response(resp)

if payload:
yield payload
self.record_latency(request_id, start_ms)

if done:
self.ten_env.log_info(
f"Response is completed for request: {request_id}."
)
break

except websockets.exceptions.ConnectionClosedError as e:
self.ten_env.log_error(
f"Connection is closed with error: {e}, request: {request_id}."
)
await self.connect()
except asyncio.TimeoutError:
self.ten_env.log_error("Timeout waiting for response.")
2 changes: 1 addition & 1 deletion agents/ten_packages/extension/bytedance_tts/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ async def on_request_tts(
await self.send_audio_out(ten_env, audio_data)

async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None:
return await super().on_cancel_tts(ten_env)
await self.client.cancel()
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
asyncio
websockets
websockets==13.1

0 comments on commit c0f2956

Please sign in to comment.