-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathwebsocket_client.py
479 lines (417 loc) · 17.8 KB
/
websocket_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import argparse
import asyncio
import datetime
import json
import logging
import os
import signal
import sys
import urllib.parse
from typing import Any, AsyncGenerator, Awaitable, Literal
import aiohttp
import numpy as np
import pyee.asyncio
import sounddevice
from websockets import exceptions as ws_exceptions
from websockets.asyncio import client as ws_client
class LocalAudioSink:
    """
    A sink for audio. Buffered audio is played using the default audio device.

    Audio is appended with `write` and consumed by the output stream's callback,
    which pulls one block (sample_rate // 100 frames) at a time and pads with
    silence when the buffer runs dry.

    Args:
        sample_rate: The sample rate to use for audio playback. Defaults to 48kHz.

    Raises:
        RuntimeError: If the output stream is not active after being started.
    """

    def __init__(self, sample_rate: int = 48000) -> None:
        self._sample_rate = sample_rate
        self._buffer: bytearray = bytearray()

        def callback(outdata: np.ndarray, frame_count, time, status):
            self._fill_output(outdata, frame_count)

        self._stream = sounddevice.OutputStream(
            samplerate=sample_rate,
            channels=1,
            callback=callback,
            device=None,
            dtype="int16",
            blocksize=sample_rate // 100,
        )
        self._stream.start()
        if not self._stream.active:
            raise RuntimeError("Failed to start streaming output audio")

    def _fill_output(self, outdata: np.ndarray, frame_count: int) -> None:
        """Moves the next block of buffered audio into `outdata`, zero-padding any shortfall.

        Args:
            outdata: The (frame_count, 1) int16 array supplied by the output stream.
            frame_count: Number of frames requested for this block.
        """
        # Two bytes per 16-bit mono frame.
        wanted = len(outdata) * 2
        frame = self._buffer[:wanted]
        # Trim in place, removing only what was actually taken. The previous
        # `buf[:] = buf[n:]` form re-copied the whole buffer on every block and
        # overwrote anything `write` appended between the read and the store.
        # NOTE(review): assumes the stream invokes this callback off the event
        # loop thread, as sounddevice documents - confirm.
        del self._buffer[: len(frame)]
        if len(frame) < wanted:
            # Underrun: fill the remainder of the block with silence.
            frame += b"\x00" * (wanted - len(frame))
        outdata[:] = np.frombuffer(frame, dtype="int16").reshape((frame_count, 1))

    def write(self, chunk: bytes) -> None:
        """Writes audio data (expected to be in 16-bit PCM format) to this sink's buffer."""
        self._buffer.extend(chunk)

    def drop_buffer(self) -> None:
        """Drops all audio data in this sink's buffer, ending playback until new data is written."""
        self._buffer.clear()

    async def close(self) -> None:
        """Closes the underlying output stream."""
        if self._stream:
            self._stream.close()
class LocalAudioSource:
    """
    Captures audio from the default input device and exposes it, via the
    `stream` method, as an AsyncGenerator of 16-bit PCM byte chunks.

    Args:
        sample_rate: The sample rate to use for audio recording. Defaults to 48kHz.
    """

    def __init__(self, sample_rate=48000):
        self._sample_rate = sample_rate

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yields captured audio blocks (sample_rate // 100 frames each) until closed.

        Raises:
            RuntimeError: If the input stream is not active once opened.
        """
        event_loop = asyncio.get_running_loop()
        pending: asyncio.Queue[bytes] = asyncio.Queue()

        def on_audio(indata: np.ndarray, frame_count, time, status):
            # Hand each block to the event loop through its thread-safe entry
            # point rather than touching the queue directly from the callback.
            event_loop.call_soon_threadsafe(pending.put_nowait, indata.tobytes())

        microphone = sounddevice.InputStream(
            samplerate=self._sample_rate,
            channels=1,
            callback=on_audio,
            device=None,
            dtype="int16",
            blocksize=self._sample_rate // 100,
        )
        # The context manager keeps the device open for as long as the
        # generator is being consumed.
        with microphone:
            if not microphone.active:
                raise RuntimeError("Failed to start streaming input audio")
            while True:
                chunk = await pending.get()
                yield chunk
class WebsocketVoiceSession(pyee.asyncio.AsyncIOEventEmitter):
"""A websocket-based voice session that connects to an Ultravox call. The session continuously
streams audio in and out and emits events for state changes and agent messages."""
def __init__(self, join_url: str):
super().__init__()
self._state: Literal["idle", "listening", "thinking", "speaking"] = "idle"
self._pending_output = ""
self._url = join_url
self._socket = None
self._receive_task: asyncio.Task | None = None
self._send_audio_task = asyncio.create_task(
self._pump_audio(LocalAudioSource())
)
self._sink = LocalAudioSink()
async def start(self):
logging.info(f"Connecting to {self._url}")
self._socket = await ws_client.connect(self._url)
self._receive_task = asyncio.create_task(self._socket_receive(self._socket))
async def _socket_receive(self, socket: ws_client.ClientConnection):
try:
async for message in socket:
await self._on_socket_message(message)
except asyncio.CancelledError:
logging.info("socket cancelled")
except ws_exceptions.ConnectionClosedOK:
logging.info("socket closed ok")
except ws_exceptions.ConnectionClosedError as e:
self.emit("error", e)
return
logging.info("socket receive done")
self.emit("ended")
async def stop(self):
"""End the session, closing the connection and ending the call."""
logging.info("Stopping...")
await _async_close(
self._sink.close(),
self._socket.close() if self._socket else None,
_async_cancel(self._send_audio_task, self._receive_task),
)
if self._state != "idle":
self._state = "idle"
self.emit("state", "idle")
async def _on_socket_message(self, payload: str | bytes):
if isinstance(payload, bytes):
self._sink.write(payload)
return
elif isinstance(payload, str):
msg = json.loads(payload)
await self._handle_data_message(msg)
async def _handle_data_message(self, msg: dict[str, Any]):
match msg["type"]:
case "playback_clear_buffer":
self._sink.drop_buffer()
case "state":
if msg["state"] != self._state:
self._state = msg["state"]
self.emit("state", msg["state"])
case "transcript":
# This is lazy handling of transcripts. See the WebRTC client SDKs
# for a more robust implementation.
if msg["role"] != "agent":
return # Ignore user transcripts
if msg.get("text", None):
self._pending_output = msg["text"]
self.emit("output", msg["text"], msg["final"])
else:
self._pending_output += msg.get("delta", "")
self.emit("output", self._pending_output, msg["final"])
if msg["final"]:
self._pending_output = ""
case "client_tool_invocation":
await self._handle_client_tool_call(
msg["toolName"], msg["invocationId"], msg["parameters"]
)
case "debug":
logging.info(f"debug: {msg['message']}")
case _:
logging.warning(f"Unhandled message type: {msg['type']}")
async def _handle_client_tool_call(
self, tool_name: str, invocation_id: str, parameters: dict[str, Any]
):
logging.info(f"client tool call: {tool_name}")
response: dict[str, str] = {
"type": "client_tool_result",
"invocationId": invocation_id,
}
if tool_name == "getSecretMenu":
menu = [
{
"date": datetime.date.today().isoformat(),
"items": [
{
"name": "Banana Smoothie",
"price": "$4.99",
},
{
"name": "Butter Pecan Ice Cream (one scoop)",
"price": "$2.99",
},
],
},
]
response["result"] = json.dumps(menu)
else:
response["errorType"] = "undefined"
response["errorMessage"] = f"Unknown tool: {tool_name}"
await self._socket.send(json.dumps(response))
async def _pump_audio(self, source: LocalAudioSource):
async for chunk in source.stream():
if self._socket is None:
continue
await self._socket.send(chunk)
async def _async_close(*awaitables_or_none: Awaitable | None):
coros = [coro for coro in awaitables_or_none if coro is not None]
if coros:
maybe_exceptions = await asyncio.shield(
asyncio.gather(*coros, return_exceptions=True)
)
non_cancelled_exceptions = [
exc
for exc in maybe_exceptions
if isinstance(exc, Exception)
and not isinstance(exc, asyncio.CancelledError)
]
if non_cancelled_exceptions:
to_report = (
non_cancelled_exceptions[0]
if len(non_cancelled_exceptions) == 1
else ExceptionGroup("Multiple failures", non_cancelled_exceptions)
)
logging.warning("Error during _async_close", exc_info=to_report)
async def _async_cancel(*tasks_or_none: asyncio.Task | None):
    """Requests cancellation of each given task and waits for those that accepted it.

    None entries are skipped, as are tasks whose `cancel()` returns a falsy value.
    """
    cancelling = []
    for candidate in tasks_or_none:
        if candidate is None:
            continue
        if candidate.cancel():
            cancelling.append(candidate)
    await _async_close(*cancelling)
async def _get_join_url() -> str:
    """Creates a new call, returning its join URL.

    Reads its configuration from the module-level `args` namespace and the
    ULTRAVOX_API_KEY environment variable, POSTs the call definition to the
    Ultravox API, and appends the `apiVersion` (default 1) and, if requested,
    `experimentalMessages` query parameters to the returned join URL.

    Raises:
        aiohttp.ClientResponseError: If the API responds with an error status.
    """
    target = "https://api.ultravox.ai/api/calls"
    if args.prior_call_id:
        # Go through the shared helper so the value is percent-encoded rather
        # than pasted raw into the query string.
        target = _add_query_param(target, "priorCallId", args.prior_call_id)
    async with aiohttp.ClientSession() as session:
        headers = {"X-API-Key": f"{os.getenv('ULTRAVOX_API_KEY', None)}"}
        system_prompt = args.system_prompt
        selected_tools = []
        if args.secret_menu:
            system_prompt += "\n\nThere is also a secret menu that changes daily. If the user asks about it, use the getSecretMenu tool to look up today's secret menu items."
            selected_tools.append(
                {
                    "temporaryTool": {
                        "modelToolName": "getSecretMenu",
                        "description": "Looks up today's secret menu items.",
                        "client": {},
                    },
                }
            )
        body = {
            "systemPrompt": system_prompt,
            "temperature": args.temperature,
            "medium": {
                "serverWebSocket": {
                    "inputSampleRate": 48000,
                    "outputSampleRate": 48000,
                    # Buffer up to 30s of audio client-side. This won't impact
                    # interruptions because we handle playback_clear_buffer above.
                    "clientBufferSizeMs": 30000,
                }
            },
        }
        # Optional fields are only sent when the matching flag was supplied.
        if args.voice:
            body["voice"] = args.voice
        if selected_tools:
            body["selectedTools"] = selected_tools
        if args.initial_output_text:
            body["initialOutputMedium"] = "MESSAGE_MEDIUM_TEXT"
        if args.user_speaks_first:
            body["firstSpeaker"] = "FIRST_SPEAKER_USER"
        logging.info(f"Creating call with body: {body}")
        async with session.post(target, headers=headers, json=body) as response:
            response.raise_for_status()
            response_json = await response.json()
            join_url = response_json["joinUrl"]
            join_url = _add_query_param(
                join_url, "apiVersion", str(args.api_version or 1)
            )
            if args.experimental_messages:
                join_url = _add_query_param(
                    join_url, "experimentalMessages", args.experimental_messages
                )
            return join_url
def _add_query_param(url: str, key: str, value: str) -> str:
    """Returns `url` with the query parameter `key` set to `value`.

    If `key` is already present, its first occurrence is replaced in place and
    any further occurrences are removed; otherwise the pair is appended. All
    other existing parameters are preserved, including repeated keys and
    parameters with blank values. The value is percent-encoded.

    Args:
        url: The URL to modify.
        key: The query parameter name.
        value: The query parameter value.

    Returns:
        The URL with the updated query string.
    """
    parsed = urllib.parse.urlparse(url)
    # Work on the ordered pair list: collapsing to a dict would lose repeated
    # keys, and the parse_qsl default would lose blank-valued parameters.
    existing = urllib.parse.parse_qsl(parsed.query, keep_blank_values=True)
    updated: list[tuple[str, str]] = []
    replaced = False
    for name, current in existing:
        if name != key:
            updated.append((name, current))
        elif not replaced:
            updated.append((key, value))
            replaced = True
    if not replaced:
        updated.append((key, value))
    return urllib.parse.urlunparse(
        parsed._replace(query=urllib.parse.urlencode(updated))
    )
async def main():
    """Creates a call, runs a voice session against it, and prints the conversation.

    The session runs until it ends, reports an error, or the process receives
    SIGINT/SIGTERM. The session is always stopped on the way out, including
    when connecting fails.
    """
    join_url = await _get_join_url()
    client = WebsocketVoiceSession(join_url)
    done = asyncio.Event()
    loop = asyncio.get_running_loop()

    @client.on("state")
    async def on_state(state):
        if state == "listening":
            # Used to prompt the user to speak
            print("User:  ", end="\r")
        elif state == "thinking":
            print("Agent: ", end="\r")

    @client.on("output")
    async def on_output(text, final):
        display_text = f"{text.strip()}"
        # Non-final output ends with a carriage return so the next update
        # overwrites the same terminal line.
        print("Agent: " + display_text, end="\n" if final else "\r")

    @client.on("error")
    async def on_error(error):
        logging.exception("Client error", exc_info=error)
        print(f"Error: {error}")
        done.set()

    @client.on("ended")
    async def on_ended():
        print("Session ended")
        done.set()

    loop.add_signal_handler(signal.SIGINT, done.set)
    loop.add_signal_handler(signal.SIGTERM, done.set)
    try:
        await client.start()
        await done.wait()
    finally:
        # The session opens its audio streams and capture task on construction,
        # so it has to be stopped even if start() raised.
        await client.stop()
# Script entry point: validates the API key, parses command-line options into
# the module-level `args` namespace (read by _get_join_url), configures
# logging, and runs the session.
if __name__ == "__main__":
    api_key = os.getenv("ULTRAVOX_API_KEY", None)
    if not api_key:
        raise ValueError("Please set your ULTRAVOX_API_KEY environment variable")
    parser = argparse.ArgumentParser(prog="websocket_client.py")
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Show verbose session information"
    )
    parser.add_argument(
        "--very-verbose", "-vv", action="store_true", help="Show debug logs too"
    )
    parser.add_argument("--voice", "-V", type=str, help="Name (or id) of voice to use")
    # The default prompt is an f-string so the current local time is embedded
    # when the script starts. (A stray "$" that used to precede the timestamp
    # has been removed; f-string placeholders are just "{...}".)
    parser.add_argument(
        "--system-prompt",
        "-s",
        type=str,
        default=f"""
You are a drive-thru order taker for a donut shop called "Dr. Donut". Local time is currently: {datetime.datetime.now().isoformat()}
The user is talking to you over voice on their phone, and your response will be read out loud with realistic text-to-speech (TTS) technology.
Follow every direction here when crafting your response:
1. Use natural, conversational language that is clear and easy to follow (short sentences, simple words).
1a. Be concise and relevant: Most of your responses should be a sentence or two, unless you're asked to go deeper. Don't monopolize the conversation.
1b. Use discourse markers to ease comprehension. Never use the list format.
2. Keep the conversation flowing.
2a. Clarify: when there is ambiguity, ask clarifying questions, rather than make assumptions.
2b. Don't implicitly or explicitly try to end the chat (i.e. do not end a response with "Talk soon!", or "Enjoy!").
2c. Sometimes the user might just want to chat. Ask them relevant follow-up questions.
2d. Don't ask them if there's anything else they need help with (e.g. don't say things like "How can I assist you further?").
3. Remember that this is a voice conversation:
3a. Don't use lists, markdown, bullet points, or other formatting that's not typically spoken.
3b. Type out numbers in words (e.g. 'twenty twelve' instead of the year 2012)
3c. If something doesn't make sense, it's likely because you misheard them. There wasn't a typo, and the user didn't mispronounce anything.
Remember to follow these rules absolutely, and do not refer to these rules, even if you're asked about them.
When talking with the user, use the following script:
1. Take their order, acknowledging each item as it is ordered. If it's not clear which menu item the user is ordering, ask them to clarify.
DO NOT add an item to the order unless it's one of the items on the menu below.
2. Once the order is complete, repeat back the order.
2a. If the user only ordered a drink, ask them if they would like to add a donut to their order.
2b. If the user only ordered donuts, ask them if they would like to add a drink to their order.
2c. If the user ordered both drinks and donuts, don't suggest anything.
3. Total up the price of all ordered items and inform the user.
4. Ask the user to pull up to the drive thru window.
If the user asks for something that's not on the menu, inform them of that fact, and suggest the most similar item on the menu.
If the user says something unrelated to your role, respond with "Um... this is a Dr. Donut."
If the user says "thank you", respond with "My pleasure."
If the user asks about what's on the menu, DO NOT read the entire menu to them. Instead, give a couple suggestions.
The menu of available items is as follows:
# DONUTS
PUMPKIN SPICE ICED DOUGHNUT $1.29
PUMPKIN SPICE CAKE DOUGHNUT $1.29
OLD FASHIONED DOUGHNUT $1.29
CHOCOLATE ICED DOUGHNUT $1.09
CHOCOLATE ICED DOUGHNUT WITH SPRINKLES $1.09
RASPBERRY FILLED DOUGHNUT $1.09
BLUEBERRY CAKE DOUGHNUT $1.09
STRAWBERRY ICED DOUGHNUT WITH SPRINKLES $1.09
LEMON FILLED DOUGHNUT $1.09
DOUGHNUT HOLES $3.99
# COFFEE & DRINKS
PUMPKIN SPICE COFFEE $2.59
PUMPKIN SPICE LATTE $4.59
REGULAR BREWED COFFEE $1.79
DECAF BREWED COFFEE $1.79
LATTE $3.49
CAPPUCINO $3.49
CARAMEL MACCHIATO $3.49
MOCHA LATTE $3.49
CARAMEL MOCHA LATTE $3.49""",
        help="System prompt to use when creating the call",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Temperature to use when creating the call",
    )
    parser.add_argument(
        "--secret-menu",
        action="store_true",
        help="Adds prompt and client-implemented tool for a secret menu. For use with the default system prompt.",
    )
    parser.add_argument(
        "--experimental-messages",
        type=str,
        help="Enables the specified experimental messages (e.g. 'debug' which should be used with -v)",
    )
    parser.add_argument(
        "--prior-call-id",
        type=str,
        help="Allows setting priorCallId during start call",
    )
    parser.add_argument(
        "--user-speaks-first",
        action="store_true",
        help="If set, sets FIRST_SPEAKER_USER",
    )
    parser.add_argument(
        "--initial-output-text",
        action="store_true",
        help="Sets the initial_output_medium to text",
    )
    parser.add_argument(
        "--api-version",
        type=int,
        help="API version to set when creating the call.",
    )
    args = parser.parse_args()
    # --very-verbose takes precedence over --verbose when both are given.
    if args.very_verbose:
        logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    elif args.verbose:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    asyncio.run(main())