diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 1e615b0..9cd774b 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -1,6 +1,6 @@ from .audio.rooms import CreateRoomResp from .audio.speech import AudioFormat -from .audio.transcriptions import CreateTranslationResp +from .audio.transcriptions import CreateTranscriptionsResp from .audio.voices import Voice from .auth import ( AsyncDeviceOAuthApp, @@ -92,27 +92,27 @@ from .templates import TemplateDuplicateResp, TemplateEntityType from .version import VERSION from .websockets.audio.speech import ( - AsyncWebsocketsAudioSpeechCreateClient, + AsyncWebsocketsAudioSpeechClient, AsyncWebsocketsAudioSpeechEventHandler, InputTextBufferAppendEvent, - InputTextBufferCommitEvent, - InputTextBufferCommittedEvent, + InputTextBufferCompletedEvent, + InputTextBufferCompleteEvent, SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, SpeechUpdateEvent, ) from .websockets.audio.transcriptions import ( - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, InputAudioBufferAppendEvent, - InputAudioBufferCommitEvent, - InputAudioBufferCommittedEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, TranscriptionsMessageCompletedEvent, TranscriptionsMessageUpdateEvent, TranscriptionsUpdateEvent, ) from .websockets.chat import ( - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, ConversationAudioDeltaEvent, ConversationChatCompletedEvent, @@ -143,7 +143,7 @@ "Voice", "AudioFormat", # audio.transcriptions - "CreateTranslationResp", + "CreateTranscriptionsResp", # auth "AsyncDeviceOAuthApp", "AsyncJWTOAuthApp", @@ -217,29 +217,29 @@ "WebsocketsEvent", # websockets.audio.speech "InputTextBufferAppendEvent", - "InputTextBufferCommitEvent", + "InputTextBufferCompleteEvent", "SpeechUpdateEvent", - "InputTextBufferCommittedEvent", + "InputTextBufferCompletedEvent", "SpeechAudioUpdateEvent", "SpeechAudioCompletedEvent", "AsyncWebsocketsAudioSpeechEventHandler", - "AsyncWebsocketsAudioSpeechCreateClient", + "AsyncWebsocketsAudioSpeechClient", # websockets.audio.transcriptions "InputAudioBufferAppendEvent", - "InputAudioBufferCommitEvent", + "InputAudioBufferCompleteEvent", "TranscriptionsUpdateEvent", - "InputAudioBufferCommittedEvent", + "InputAudioBufferCompletedEvent", "TranscriptionsMessageUpdateEvent", "TranscriptionsMessageCompletedEvent", "AsyncWebsocketsAudioTranscriptionsEventHandler", - "AsyncWebsocketsAudioTranscriptionsCreateClient", + "AsyncWebsocketsAudioTranscriptionsClient", # websockets.chat "ConversationChatCreatedEvent", "ConversationMessageDeltaEvent", "ConversationAudioDeltaEvent", "ConversationChatCompletedEvent", "AsyncWebsocketsChatEventHandler", - "AsyncWebsocketsChatCreateClient", + "AsyncWebsocketsChatClient", # workflows.runs "WorkflowRunResult", "WorkflowEventType", diff --git a/cozepy/audio/transcriptions/__init__.py b/cozepy/audio/transcriptions/__init__.py index 7f37181..bac1f82 100644 --- a/cozepy/audio/transcriptions/__init__.py +++ b/cozepy/audio/transcriptions/__init__.py @@ -7,7 +7,7 @@ from cozepy.util import remove_url_trailing_slash -class CreateTranslationResp(CozeModel): +class CreateTranscriptionsResp(CozeModel): # The text of translation text: str @@ -23,18 +23,18 @@ def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" headers: Optional[dict] = kwargs.get("headers") files = {"file": _try_fix_file(file)} return self._requester.request( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) @@ -53,16 +53,16 @@ async def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" files = {"file": _try_fix_file(file)} headers: Optional[dict] = kwargs.get("headers") return await self._requester.arequest( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) diff --git a/cozepy/websockets/__init__.py b/cozepy/websockets/__init__.py index d433773..6353636 100644 --- a/cozepy/websockets/__init__.py +++ b/cozepy/websockets/__init__.py @@ -2,8 +2,31 @@ from cozepy.request import Requester from cozepy.util import http_base_url_to_ws, remove_url_trailing_slash -from .audio import AsyncWebsocketsAudioClient -from .chat import AsyncWebsocketsChatClient +from .audio import AsyncWebsocketsAudioClient, WebsocketsAudioClient +from .chat import AsyncWebsocketsChatBuildClient, WebsocketsChatBuildClient + + +class WebsocketsClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = http_base_url_to_ws(remove_url_trailing_slash(base_url)) + self._auth = auth + self._requester = requester + + @property + def audio(self) -> WebsocketsAudioClient: + return WebsocketsAudioClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def chat(self) -> WebsocketsChatBuildClient: + return WebsocketsChatBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsClient(object): @@ -21,8 +44,8 @@ def audio(self) -> AsyncWebsocketsAudioClient: ) @property - def chat(self) -> AsyncWebsocketsChatClient: - return AsyncWebsocketsChatClient( + def chat(self) -> AsyncWebsocketsChatBuildClient: + return AsyncWebsocketsChatBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/__init__.py b/cozepy/websockets/audio/__init__.py index 3f05e92..a02f3ea 100644 --- a/cozepy/websockets/audio/__init__.py +++ b/cozepy/websockets/audio/__init__.py @@ -1,8 +1,31 @@ from cozepy.auth import Auth from cozepy.request import Requester -from .speech import AsyncWebsocketsAudioSpeechClient -from .transcriptions import AsyncWebsocketsAudioTranscriptionsClient +from .speech import AsyncWebsocketsAudioSpeechBuildClient, WebsocketsAudioSpeechBuildClient +from .transcriptions import AsyncWebsocketsAudioTranscriptionsBuildClient, WebsocketsAudioTranscriptionsBuildClient + + +class WebsocketsAudioClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = base_url + self._auth = auth + self._requester = requester + + @property + def transcriptions(self) -> "WebsocketsAudioTranscriptionsBuildClient": + return WebsocketsAudioTranscriptionsBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def speech(self) -> "WebsocketsAudioSpeechBuildClient": + return WebsocketsAudioSpeechBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsAudioClient(object): @@ -12,16 +35,16 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester @property - def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsClient": - return AsyncWebsocketsAudioTranscriptionsClient( + def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsBuildClient": + return AsyncWebsocketsAudioTranscriptionsBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, ) @property - def speech(self) -> "AsyncWebsocketsAudioSpeechClient": - return AsyncWebsocketsAudioSpeechClient( + def speech(self) -> "AsyncWebsocketsAudioSpeechBuildClient": + return AsyncWebsocketsAudioSpeechBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/speech/__init__.py b/cozepy/websockets/audio/speech/__init__.py index 8190d10..2796987 100644 --- a/cozepy/websockets/audio/speech/__init__.py +++ b/cozepy/websockets/audio/speech/__init__.py @@ -10,6 +10,9 @@ from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + OutputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -20,42 +23,27 @@ class InputTextBufferAppendEvent(WebsocketsEvent): class Data(BaseModel): delta: str - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_APPEND + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_APPEND data: Data # req -class InputTextBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMIT +class InputTextBufferCompleteEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETE # req class SpeechUpdateEvent(WebsocketsEvent): - class OpusConfig(object): - bitrate: Optional[int] = None - use_cbr: Optional[bool] = None - frame_size_ms: Optional[float] = None - - class PCMConfig(object): - sample_rate: Optional[int] = None - - class OutputAudio(object): - codec: Optional[str] - pcm_config: Optional["SpeechUpdateEvent.PCMConfig"] - opus_config: Optional["SpeechUpdateEvent.OpusConfig"] - speech_rate: Optional[int] - voice_id: Optional[str] - - class Data: - output_audio: "SpeechUpdateEvent.OutputAudio" + class Data(BaseModel): + output_audio: Optional[OutputAudio] = None - type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE data: Data # resp -class InputTextBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED +class InputTextBufferCompletedEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED # resp @@ -67,91 +55,199 @@ class Data(BaseModel): def serialize_delta(self, delta: bytes, _info): return base64.b64encode(delta) - type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_UPDATE data: Data # resp class SpeechAudioCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_COMPLETED -class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_text_buffer_committed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: InputTextBufferCommittedEvent - ): +class WebsocketsAudioSpeechEventHandler(WebsocketsBaseEventHandler): + def on_input_text_buffer_completed(self, cli: "WebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent): pass - async def on_speech_audio_update( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioUpdateEvent - ): + def on_speech_audio_update(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): pass - async def on_speech_audio_completed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioCompletedEvent - ): + def on_speech_audio_completed(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent): pass -class AsyncWebsocketsAudioSpeechCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioSpeechClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioSpeechEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: on_event.on_input_text_buffer_committed, - WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, - WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, - } + if isinstance(on_event, WebsocketsAudioSpeechEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) super().__init__( base_url=base_url, auth=auth, requester=requester, path="v1/audio/speech", - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.SPEECH_AUDIO_COMPLETED], **kwargs, ) - async def append(self, text: str) -> None: - await self._input_queue.put( - InputTextBufferAppendEvent.model_validate( + def input_text_buffer_append(self, data: InputTextBufferAppendEvent.Data) -> None: + self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + def input_text_buffer_complete(self) -> None: + self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) + + def speech_update(self, event: SpeechUpdateEvent) -> None: + self._input_queue.put(event) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) + if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: + delta_base64 = data.get("delta") + if delta_base64 is None: + raise ValueError("Missing 'delta' in event data") + return SpeechAudioUpdateEvent.model_validate( { - "data": InputTextBufferAppendEvent.Data.model_validate( + "id": event_id, + "detail": detail, + "data": SpeechAudioUpdateEvent.Data.model_validate( { - "delta": text, + "delta": base64.b64decode(delta_base64), } - ) + ), + } + ) + elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: + return SpeechAudioCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, } ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) + return None + + +class WebsocketsAudioSpeechBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, *, on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs + ) -> WebsocketsAudioSpeechClient: + return WebsocketsAudioSpeechClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, # type: ignore + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputTextBufferCommitEvent.model_validate({})) - async def update(self, event: SpeechUpdateEvent) -> None: - await self._input_queue.put(event) +class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + pass + + async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): + pass + + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioSpeechClient(AsyncWebsocketsBaseClient): + class EventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + pass + + async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): + pass + + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + pass + + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union["AsyncWebsocketsAudioSpeechClient.EventHandler", Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioSpeechClient.EventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/speech", + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.SPEECH_AUDIO_COMPLETED], + **kwargs, + ) + + async def input_text_buffer_append(self, data: InputTextBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + async def input_text_buffer_complete(self) -> None: + await self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) + + async def speech_update(self, data: SpeechUpdateEvent.Data) -> None: + await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: - return InputTextBufferCommittedEvent.model_validate({"event_id": event_id}) + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate({"id": event_id, "detail": detail}) if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: delta_base64 = data.get("delta") if delta_base64 is None: raise ValueError("Missing 'delta' in event data") return SpeechAudioUpdateEvent.model_validate( { - "event_id": event_id, + "id": event_id, + "detail": detail, "data": SpeechAudioUpdateEvent.Data.model_validate( { "delta": base64.b64decode(delta_base64), @@ -160,25 +256,33 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: } ) elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: - return SpeechAudioCompletedEvent.model_validate({"event_id": event_id}) + return SpeechAudioCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None -class AsyncWebsocketsAudioSpeechClient: +class AsyncWebsocketsAudioSpeechBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, *, on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs - ) -> AsyncWebsocketsAudioSpeechCreateClient: - return AsyncWebsocketsAudioSpeechCreateClient( + self, + *, + on_event: Union[AsyncWebsocketsAudioSpeechClient.EventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> AsyncWebsocketsAudioSpeechClient: + return AsyncWebsocketsAudioSpeechClient( base_url=self._base_url, auth=self._auth, requester=self._requester, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/audio/transcriptions/__init__.py b/cozepy/websockets/audio/transcriptions/__init__.py index e6dfb4c..cdc5c48 100644 --- a/cozepy/websockets/audio/transcriptions/__init__.py +++ b/cozepy/websockets/audio/transcriptions/__init__.py @@ -5,12 +5,14 @@ from cozepy.auth import Auth from cozepy.log import log_warning -from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + InputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -25,34 +27,27 @@ class Data(BaseModel): def serialize_delta(self, delta: bytes, _info): return base64.b64encode(delta) - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND data: Data # req -class InputAudioBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMIT +class InputAudioBufferCompleteEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETE # req class TranscriptionsUpdateEvent(WebsocketsEvent): - class InputAudio(CozeModel): - format: Optional[str] - codec: Optional[str] - sample_rate: Optional[int] - channel: Optional[int] - bit_depth: Optional[int] - class Data(BaseModel): - input_audio: "TranscriptionsUpdateEvent.InputAudio" + input_audio: Optional[InputAudio] = None - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_UPDATE data: Data # resp -class InputAudioBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED +class InputAudioBufferCompletedEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED # resp @@ -60,92 +55,192 @@ class TranscriptionsMessageUpdateEvent(WebsocketsEvent): class Data(BaseModel): content: str - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE data: Data # resp class TranscriptionsMessageCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED -class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_audio_buffer_committed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: InputAudioBufferCommittedEvent +class WebsocketsAudioTranscriptionsEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): pass - async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + def on_transcriptions_message_update( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): pass - async def on_transcriptions_message_completed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageCompletedEvent + def on_transcriptions_message_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent ): pass -class AsyncWebsocketsAudioTranscriptionsCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioTranscriptionsClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED: on_event.on_input_audio_buffer_committed, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, - } + if isinstance(on_event, WebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) super().__init__( base_url=base_url, auth=auth, requester=requester, path="v1/audio/transcriptions", - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED], **kwargs, ) - async def update(self, data: TranscriptionsUpdateEvent.InputAudio) -> None: - await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + def transcriptions_update(self, data: TranscriptionsUpdateEvent.Data) -> None: + self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( { - "data": InputAudioBufferAppendEvent.Data.model_validate( + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE.value: + return TranscriptionsMessageUpdateEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": TranscriptionsMessageUpdateEvent.Data.model_validate( { - "delta": delta, + "content": data.get("content") or "", } - ) + ), + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, } ) + else: + log_warning("[v1/audio/transcriptions] unknown event=%s, logid=%s", event_type, detail.logid) + return None + + +class WebsocketsAudioTranscriptionsBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsAudioTranscriptionsClient: + return WebsocketsAudioTranscriptionsClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, # type: ignore + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + +class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_transcriptions_message_update( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + ): + pass + + async def on_transcriptions_message_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioTranscriptionsClient(AsyncWebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/transcriptions", + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED], + **kwargs, + ) + + async def transcriptions_update(self, data: InputAudio) -> None: + await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + event_type = message.get("event_type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: - return TranscriptionsMessageCompletedEvent.model_validate( + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( { - "event_id": event_id, + "id": event_id, } ) elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE.value: return TranscriptionsMessageUpdateEvent.model_validate( { - "event_id": event_id, + "id": event_id, "data": TranscriptionsMessageUpdateEvent.Data.model_validate( { "content": data.get("content") or "", @@ -153,14 +248,18 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ), } ) - elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED.value: - pass + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "id": event_id, + } + ) else: log_warning("[v1/audio/transcriptions] unknown event=%s", event_type) return None -class AsyncWebsocketsAudioTranscriptionsClient: +class AsyncWebsocketsAudioTranscriptionsBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -171,11 +270,11 @@ def create( *, on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsAudioTranscriptionsCreateClient: - return AsyncWebsocketsAudioTranscriptionsCreateClient( + ) -> AsyncWebsocketsAudioTranscriptionsClient: + return AsyncWebsocketsAudioTranscriptionsClient( base_url=self._base_url, auth=self._auth, requester=self._requester, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 1bcb1b9..35b3045 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -1,66 +1,264 @@ -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Union -from cozepy import Chat, Message +from pydantic import BaseModel + +from cozepy import Chat, Message, ToolOutput from cozepy.auth import Auth from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash -from cozepy.websockets.audio.transcriptions import InputAudioBufferAppendEvent, InputAudioBufferCommitEvent +from cozepy.websockets.audio.transcriptions import ( + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, +) from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + InputAudio, + OutputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) +# req +class ChatUpdateEvent(WebsocketsEvent): + class ChatConfig(BaseModel): + conversation_id: Optional[str] = None + user_id: Optional[str] = None + meta_data: Optional[Dict[str, str]] = None + custom_variables: Optional[Dict[str, str]] = None + extra_params: Optional[Dict[str, str]] = None + auto_save_history: Optional[bool] = None + + class Data(BaseModel): + output_audio: Optional[OutputAudio] = None + input_audio: Optional[InputAudio] = None + chat_config: Optional["ChatUpdateEvent.ChatConfig"] = None + + event_type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATE + + +# req +class ConversationChatSubmitToolOutputsEvent(WebsocketsEvent): + class Data(BaseModel): + chat_id: str + tool_outputs: List[ToolOutput] + + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS + data: Data + + # resp class ConversationChatCreatedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED data: Chat # resp class ConversationMessageDeltaEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_MESSAGE_DELTA + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_MESSAGE_DELTA data: Message +# resp +class ConversationChatRequiresActionEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION + data: Chat + + # resp class ConversationAudioDeltaEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_DELTA + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_DELTA data: Message # resp class ConversationChatCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_COMPLETED data: Chat +class WebsocketsChatEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent): + pass + + def on_conversation_chat_created(self, cli: "WebsocketsChatClient", event: ConversationChatCreatedEvent): + pass + + def on_conversation_message_delta(self, cli: "WebsocketsChatClient", event: ConversationMessageDeltaEvent): + pass + + def on_conversation_chat_requires_action( + self, cli: "WebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass + + def on_conversation_audio_delta(self, cli: "WebsocketsChatClient", event: ConversationAudioDeltaEvent): + pass + + def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: ConversationChatCompletedEvent): + pass + + +class WebsocketsChatClient(WebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, WebsocketsChatEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/chat", + query={ + "bot_id": bot_id, + }, + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED], + **kwargs, + ) + + def chat_update(self, data: ChatUpdateEvent.Data) -> None: + self._input_queue.put(ChatUpdateEvent.model_validate({"data": data})) + + def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data): + self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) + + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + return ConversationChatCreatedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: + return ConversationMessageDeltaEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: + return ConversationAudioDeltaEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: + return ConversationChatCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) + return None + + +class WebsocketsChatBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsChatClient: + return WebsocketsChatClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + bot_id=bot_id, + on_event=on_event, # type: ignore + **kwargs, + ) + + class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler): - def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCreatedEvent): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCreatedEvent): pass - def on_conversation_message_delta( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationMessageDeltaEvent + async def on_conversation_message_delta( + self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageDeltaEvent ): pass - # def on_conversation_chat_requires_action(self, cli: 'AsyncWebsocketsChatCreateClient', - # event: ConversationChatRequiresActionEvent): - # pass + async def on_conversation_chat_requires_action( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass - def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationAudioDeltaEvent): + async def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioDeltaEvent): pass - def on_conversation_chat_completed( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCompletedEvent + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent ): pass -class AsyncWebsocketsChatCreateClient(AsyncWebsocketsBaseClient): +class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient): def __init__( self, base_url: str, @@ -71,15 +269,16 @@ def __init__( **kwargs, ): if isinstance(on_event, AsyncWebsocketsChatEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, - WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, - # EventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, - WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, - WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, - } + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) super().__init__( base_url=base_url, auth=auth, @@ -88,68 +287,81 @@ def __init__( query={ "bot_id": bot_id, }, - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED], **kwargs, ) - # async def update(self, event: TranscriptionsUpdateEventInputAudio) -> None: - # await self._input_queue.put(TranscriptionsUpdateEvent.load(event)) + async def chat_update(self, data: ChatUpdateEvent.Data) -> None: + await self._input_queue.put(ChatUpdateEvent.model_validate({"data": data})) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( - { - "data": InputAudioBufferAppendEvent.Data.model_validate( - { - "delta": delta, - } - ) - } - ) - ) + async def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data) -> None: + await self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: return ConversationChatCreatedEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Chat.model_validate(data), - "event_id": event_id, } ) if event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: return ConversationMessageDeltaEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Message.model_validate(data), - "event_id": event_id, + } + ) + if event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), } ) elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: return ConversationAudioDeltaEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Message.model_validate(data), - "event_id": event_id, } ) elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: return ConversationChatCompletedEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Chat.model_validate(data), - "event_id": event_id, } ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None -class AsyncWebsocketsChatClient(object): +class AsyncWebsocketsChatBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -161,12 +373,12 @@ def create( bot_id: str, on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsChatCreateClient: - return AsyncWebsocketsChatCreateClient( + ) -> AsyncWebsocketsChatClient: + return AsyncWebsocketsChatClient( base_url=self._base_url, auth=self._auth, requester=self._requester, bot_id=bot_id, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 9e787a6..6101aec 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -1,17 +1,21 @@ import abc import asyncio import json +import queue +import threading +import traceback from abc import ABC -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, contextmanager from enum import Enum -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Set -import websockets +import websockets.asyncio.client +import websockets.sync.client +from pydantic import BaseModel from websockets import InvalidStatus -from websockets.asyncio.connection import Connection from cozepy import Auth, CozeAPIError -from cozepy.log import log_debug, log_info +from cozepy.log import log_debug, log_error, log_info from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -20,49 +24,307 @@ class WebsocketsEventType(str, Enum): # common - ERROR = "error" - CLOSED = "closed" + CLIENT_ERROR = "client_error" # sdk error + CLOSED = "closed" # connection closed + + # error + ERROR = "error" # received error event # v1/audio/speech # req - INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" - INPUT_TEXT_BUFFER_COMMIT = "input_text_buffer.commit" - SPEECH_UPDATE = "speech.update" + INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" # send text to server + INPUT_TEXT_BUFFER_COMPLETE = ( + "input_text_buffer.complete" # no text to send, after audio all received, can close connection + ) + SPEECH_UPDATE = "speech.update" # send speech config to server # resp # v1/audio/speech - INPUT_TEXT_BUFFER_COMMITTED = "input_text_buffer.committed" # ignored - SPEECH_AUDIO_UPDATE = "speech.audio.update" - SPEECH_AUDIO_COMPLETED = "speech.audio.completed" + INPUT_TEXT_BUFFER_COMPLETED = "input_text_buffer.completed" # received `input_text_buffer.complete` event + SPEECH_AUDIO_UPDATE = "speech.audio.update" # received `speech.update` event + SPEECH_AUDIO_COMPLETED = "speech.audio.completed" # all audio received, can close connection # v1/audio/transcriptions # req - INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - TRANSCRIPTIONS_UPDATE = "transcriptions.update" + INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + INPUT_AUDIO_BUFFER_COMPLETE = ( + "input_audio_buffer.complete" # no audio to send, after text all received, can close connection + ) + TRANSCRIPTIONS_UPDATE = "transcriptions.update" # send transcriptions config to server # resp - INPUT_AUDIO_BUFFER_COMMITTED = "input_audio_buffer.committed" # ignored - TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" - TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" + INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" # received `transcriptions.update` event + TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" # all audio received, can close connection # v1/chat # req - # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - # INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - CHAT_UPDATE = "chat.update" + # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + # INPUT_AUDIO_BUFFER_COMPLETE = "input_audio_buffer.complete" # no audio send, start chat + CHAT_UPDATE = "chat.update" # send chat config to server + CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server # resp - CONVERSATION_CHAT_CREATED = "conversation.chat.created" - CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" - CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" - CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" - CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" + # INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started + CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" # get agent text message update + CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" # need plugin submit + CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update + CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" # all message received, can close connection class WebsocketsEvent(CozeModel, ABC): + class Detail(BaseModel): + logid: Optional[str] = None + + event_type: WebsocketsEventType event_id: Optional[str] = None - type: WebsocketsEventType + detail: Optional[Detail] = None + + +class WebsocketsErrorEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.ERROR + data: CozeAPIError + + +class InputAudio(CozeModel): + format: Optional[str] + codec: Optional[str] + sample_rate: Optional[int] + channel: Optional[int] + bit_depth: Optional[int] + + +class OpusConfig(BaseModel): + bitrate: Optional[int] = None + use_cbr: Optional[bool] = None + frame_size_ms: Optional[float] = None + + +class PCMConfig(BaseModel): + sample_rate: Optional[int] = None + + +class OutputAudio(BaseModel): + codec: Optional[str] + pcm_config: Optional[PCMConfig] = None + opus_config: Optional[OpusConfig] = None + speech_rate: Optional[int] = None + voice_id: Optional[str] = None + + +class WebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + path: str, + query: Optional[Dict[str, str]] = None, + on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, + wait_events: Optional[List[WebsocketsEventType]] = None, + **kwargs, + ): + self._state = self.State.INITIALIZED + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + self._path = path + self._ws_url = self._base_url + "/" + path + if query: + self._ws_url += "?" + "&".join([f"{k}={v}" for k, v in query.items()]) + self._on_event = on_event.copy() if on_event else {} + self._headers = kwargs.get("headers") + self._wait_events = wait_events.copy() if wait_events else [] + + self._input_queue: queue.Queue[Optional[WebsocketsEvent]] = queue.Queue() + self._ws: Optional[websockets.sync.client.ClientConnection] = None + self._send_thread: Optional[threading.Thread] = None + self._receive_thread: Optional[threading.Thread] = None + self._completed_events: Set[WebsocketsEventType] = set() + self._completed_event = threading.Event() + + @contextmanager + def __call__(self): + try: + self.connect() + yield self + finally: + self.close() + + def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING + headers = { + "Authorization": f"Bearer {self._auth.token}", + "X-Coze-Client-User-Agent": coze_client_user_agent(), + **(self._headers or {}), + } + try: + self._ws = websockets.sync.client.connect( + self._ws_url, + user_agent_header=user_agent(), + additional_headers=headers, + ) + self._state = self.State.CONNECTED + log_info("[%s] connected to websocket", self._path) + + self._send_thread = threading.Thread(target=self._send_loop) + self._receive_thread = threading.Thread(target=self._receive_loop) + self._send_thread.start() + self._receive_thread.start() + except InvalidStatus as e: + raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e + + def wait(self, events: Optional[List[WebsocketsEventType]] = None, wait_all=True) -> None: + if events is None: + events = self._wait_events + self._wait_completed(events, wait_all=wait_all) + + def on(self, event_type: WebsocketsEventType, handler: Callable): + self._on_event[event_type] = handler + + def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING + self._close() + self._state = self.State.CLOSED + + def _send_loop(self) -> None: + try: + while True: + event = self._input_queue.get() + self._send_event(event) + self._input_queue.task_done() + except Exception as e: + self._handle_error(e) + + def _receive_loop(self) -> None: + try: + while True: + if not self._ws: + log_debug("[%s] empty websocket conn, close", self._path) + break + + data = self._ws.recv() + message = json.loads(data) + event_type = message.get("event_type") + log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) + + if handler := self._on_event.get(event_type): + if event := self._load_all_event(message): + handler(self, event) + self._completed_events.add(event_type) + self._completed_event.set() + except Exception as e: + self._handle_error(e) + + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": CozeAPIError(code, msg, WebsocketsEvent.Detail.model_validate(detail).logid), + } + ) + return self._load_event(message) + + @abc.abstractmethod + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... + + def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + while True: + if wait_all: + # 所有事件都需要完成 + if self._completed_events == set(events): + break + else: + # 任意一个事件完成即可 + if any(event in self._completed_events for event in events): + break + self._completed_event.wait() + self._completed_event.clear() + + def _handle_error(self, error: Exception) -> None: + if handler := self._on_event.get(WebsocketsEventType.ERROR): + handler(self, error) + else: + raise error + + def _close(self) -> None: + log_info("[%s] connect closed", self._path) + if self._send_thread: + self._send_thread.join() + if self._receive_thread: + self._receive_thread.join() + + if self._ws: + self._ws.close() + self._ws = None + + while not self._input_queue.empty(): + self._input_queue.get() + + if handler := self._on_event.get(WebsocketsEventType.CLOSED): + handler(self) + + def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None: + if not event or not self._ws: + return + log_debug("[%s] send event, type=%s", self._path, event.event_type.value) + self._ws.send(event.model_dump_json()) + + +class WebsocketsBaseEventHandler(object): + def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_closed(self, cli: "WebsocketsBaseClient"): + pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res class AsyncWebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + def __init__( self, base_url: str, @@ -74,6 +336,7 @@ def __init__( wait_events: Optional[List[WebsocketsEventType]] = None, **kwargs, ): + self._state = self.State.INITIALIZED self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester @@ -86,7 +349,7 @@ def __init__( self._wait_events = wait_events.copy() if wait_events else [] self._input_queue: asyncio.Queue[Optional[WebsocketsEvent]] = asyncio.Queue() - self._ws: Optional[Connection] = None + self._ws: Optional[websockets.asyncio.client.ClientConnection] = None self._send_task: Optional[asyncio.Task] = None self._receive_task: Optional[asyncio.Task] = None @@ -99,21 +362,25 @@ async def __call__(self): await self.close() async def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING headers = { "Authorization": f"Bearer {self._auth.token}", "X-Coze-Client-User-Agent": coze_client_user_agent(), **(self._headers or {}), } try: - self._ws = await websockets.connect( + self._ws = await websockets.asyncio.client.connect( self._ws_url, user_agent_header=user_agent(), additional_headers=headers, ) + self._state = self.State.CONNECTED log_info("[%s] connected to websocket", self._path) - self._receive_task = asyncio.create_task(self._receive_loop()) self._send_task = asyncio.create_task(self._send_loop()) + self._receive_task = asyncio.create_task(self._receive_loop()) except InvalidStatus as e: raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e @@ -126,13 +393,17 @@ def on(self, event_type: WebsocketsEventType, handler: Callable): self._on_event[event_type] = handler async def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING await self._close() + self._state = self.State.CLOSED async def _send_loop(self) -> None: try: while True: - if event := await self._input_queue.get(): - await self._send_event(event) + event = await self._input_queue.get() + await self._send_event(event) self._input_queue.task_done() except Exception as e: await self._handle_error(e) @@ -146,50 +417,73 @@ async def _receive_loop(self) -> None: data = await self._ws.recv() message = json.loads(data) - event_type = message.get("type") + event_type = message.get("event_type") log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) if handler := self._on_event.get(event_type): - if event := self._load_event(message): + if event := self._load_all_event(message): await handler(self, event) except Exception as e: await self._handle_error(e) + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": CozeAPIError(code, msg, WebsocketsEvent.Detail.model_validate(detail).logid), + } + ) + return self._load_event(message) + @abc.abstractmethod def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... - async def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + async def _wait_completed(self, wait_events: List[WebsocketsEventType], wait_all: bool) -> None: future: asyncio.Future[None] = asyncio.Future() - original_handlers = {} completed_events = set() - async def _handle_completed(client, event): - event_type = event.type - completed_events.add(event_type) + def _wrap_handler(event_type: WebsocketsEventType, original_handler): + async def wrapped(client, event): + # 先执行原始处理函数 + if original_handler: + await original_handler(client, event) + + # 再检查完成条件 + completed_events.add(event_type) + if wait_all: + # 所有事件都需要完成 + if completed_events == set(wait_events): + if not future.done(): + future.set_result(None) + else: + # 任意一个事件完成即可 + if not future.done(): + future.set_result(None) - if wait_all: - # 所有事件都需要完成 - if completed_events == set(events): - future.set_result(None) - else: - # 任意一个事件完成即可 - future.set_result(None) + return wrapped - # 为每个指定的事件类型临时注册处理函数 - for event_type in events: - original_handlers[event_type] = self._on_event.get(event_type) - self._on_event[event_type] = _handle_completed + # 为每个指定的事件类型包装处理函数 + origin_handlers = {} + for event_type in wait_events: + original_handler = self._on_event.get(event_type) + origin_handlers[event_type] = original_handler + self._on_event[event_type] = _wrap_handler(event_type, original_handler) try: # 等待直到满足完成条件 await future finally: # 恢复所有原来的处理函数 - for event_type, handler in original_handlers.items(): - if handler: - self._on_event[event_type] = handler - else: - self._on_event.pop(event_type, None) + for event_type in wait_events: + if original_handler := origin_handlers.get(event_type): + self._on_event[event_type] = original_handler async def _handle_error(self, error: Exception) -> None: if handler := self._on_event.get(WebsocketsEventType.ERROR): @@ -214,15 +508,30 @@ async def _close(self) -> None: if handler := self._on_event.get(WebsocketsEventType.CLOSED): await handler(self) - async def _send_event(self, event: WebsocketsEvent) -> None: - log_debug("[%s] send event, type=%s", self._path, event.type.value) - if self._ws: - await self._ws.send(event.model_dump_json(), True) + async def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None: + if not event or not self._ws: + return + log_debug("[%s] send event, type=%s", self._path, event.event_type.value) + await self._ws.send(event.model_dump_json(), True) class AsyncWebsocketsBaseEventHandler(object): - async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: Exception): - pass + async def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: CozeAPIError): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") async def on_closed(self, cli: "AsyncWebsocketsBaseClient"): pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res diff --git a/cozepy/workflows/runs/__init__.py b/cozepy/workflows/runs/__init__.py index d9dec93..a774941 100644 --- a/cozepy/workflows/runs/__init__.py +++ b/cozepy/workflows/runs/__init__.py @@ -253,7 +253,7 @@ def resume( url = f"{self._base_url}/v1/workflow/stream_resume" body = { "workflow_id": workflow_id, - "event_id": event_id, + "id": event_id, "resume_data": resume_data, "interrupt_type": interrupt_type, } @@ -385,7 +385,7 @@ async def resume( url = f"{self._base_url}/v1/workflow/stream_resume" body = { "workflow_id": workflow_id, - "event_id": event_id, + "id": event_id, "resume_data": resume_data, "interrupt_type": interrupt_type, } diff --git a/examples/websockets_audio_speech.py b/examples/websockets_audio_speech.py index d960f51..2fa9d05 100644 --- a/examples/websockets_audio_speech.py +++ b/examples/websockets_audio_speech.py @@ -5,12 +5,15 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioSpeechCreateClient, - AsyncWebsocketsAudioSpeechEventHandler, + AsyncWebsocketsAudioSpeechClient, + InputTextBufferAppendEvent, + InputTextBufferCompletedEvent, + SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, TokenAuth, setup_logging, ) +from cozepy.log import log_info from cozepy.util import write_pcm_to_wav_file from examples.utils import get_coze_api_base, get_coze_api_token @@ -21,19 +24,31 @@ kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") -class AsyncWebsocketsAudioSpeechEventHandlerSub(AsyncWebsocketsAudioSpeechEventHandler): +# todo review +class AsyncWebsocketsAudioSpeechEventHandlerSub(AsyncWebsocketsAudioSpeechClient.EventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechCreateClient, event: SpeechAudioUpdateEvent): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + log_info("[examples] Input text buffer completed") + + async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechClient, event: SpeechAudioUpdateEvent): self.delta.append(event.data.delta) - async def on_error(self, cli: AsyncWebsocketsAudioSpeechCreateClient, e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: AsyncWebsocketsAudioSpeechClient, e: Exception): + log_info("[examples] Error occurred: %s", e) - async def on_closed(self, cli: AsyncWebsocketsAudioSpeechCreateClient): - print("Speech connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") + self.delta = [] async def main(): @@ -55,8 +70,14 @@ async def main(): text = "你今天好吗? 今天天气不错呀" async with speech() as client: - await client.append(text) - await client.commit() + await client.input_text_buffer_append( + InputTextBufferAppendEvent.Data.model_validate( + { + "delta": text, + } + ) + ) + await client.input_text_buffer_complete() await client.wait() diff --git a/examples/websockets_audio_transcriptions.py b/examples/websockets_audio_transcriptions.py index 3dc47a6..573b2e3 100644 --- a/examples/websockets_audio_transcriptions.py +++ b/examples/websockets_audio_transcriptions.py @@ -5,13 +5,16 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, AudioFormat, + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, TokenAuth, TranscriptionsMessageUpdateEvent, setup_logging, ) +from cozepy.log import log_info from examples.utils import get_coze_api_base, get_coze_api_token coze_log = os.getenv("COZE_LOG") @@ -25,16 +28,21 @@ class AudioTranscriptionsEventHandlerSub(AsyncWebsocketsAudioTranscriptionsEvent Class is not required, you can also use Dict to set callback """ - async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient"): - print("Connection closed") + async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsClient"): + log_info("[examples] Connect closed") - async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsClient", e: Exception): + log_info("[examples] Error occurred: %s", e) async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): - print("Received:", event.data.content) + log_info("[examples] Received: %s", event.data.content) + + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + ): + log_info("[examples] Input audio buffer completed") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -72,9 +80,15 @@ async def main(): # Create and connect WebSocket client async with transcriptions() as client: - async for data in speech_stream(): - await client.append(data) - await client.commit() + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/examples/websockets_chat.py b/examples/websockets_chat.py index b95d93e..8f4b3d0 100644 --- a/examples/websockets_chat.py +++ b/examples/websockets_chat.py @@ -5,18 +5,22 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, AudioFormat, ConversationAudioDeltaEvent, + ConversationChatCompletedEvent, ConversationChatCreatedEvent, ConversationMessageDeltaEvent, + InputAudioBufferAppendEvent, TokenAuth, + ToolOutput, setup_logging, ) from cozepy.log import log_info from cozepy.util import write_pcm_to_wav_file +from cozepy.websockets.chat import ConversationChatRequiresActionEvent, ConversationChatSubmitToolOutputsEvent from examples.utils import get_coze_api_base, get_coze_api_token coze_log = os.getenv("COZE_LOG") @@ -27,30 +31,56 @@ class AsyncWebsocketsChatEventHandlerSub(AsyncWebsocketsChatEventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_conversation_chat_created( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationChatCreatedEvent - ): - log_info("ChatCreated") + async def on_error(self, cli: AsyncWebsocketsAudioTranscriptionsClient, e: Exception): + import traceback - async def on_conversation_message_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationMessageDeltaEvent - ): + log_info(f"Error occurred: {str(e)}") + log_info(f"Stack trace:\n{traceback.format_exc()}") + + async def on_conversation_chat_created(self, cli: AsyncWebsocketsChatClient, event: ConversationChatCreatedEvent): + log_info("[examples] asr completed, logid=%s", event.detail.logid) + + async def on_conversation_message_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationMessageDeltaEvent): print("Received:", event.data.content) - async def on_conversation_audio_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationAudioDeltaEvent + async def on_conversation_chat_requires_action( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent ): - self.delta.append(event.data.get_audio()) + def fake_run_local_plugin(): + # this is just fake outputs + return event.data.required_action.submit_tool_outputs.tool_calls[0].function.arguments + + fake_output = fake_run_local_plugin() + await cli.conversation_chat_submit_tool_outputs( + ConversationChatSubmitToolOutputsEvent.Data.model_validate( + { + "chat_id": event.data.id, + "tool_outputs": [ + ToolOutput.model_validate( + { + "tool_call_id": event.data.required_action.submit_tool_outputs.tool_calls[0].id, + "output": fake_output, + } + ) + ], + } + ) + ) - async def on_error(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient, e: Exception): - log_info(f"Error occurred: {str(e)}") + async def on_conversation_audio_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationAudioDeltaEvent): + self.delta.append(event.data.get_audio()) - async def on_closed(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient): - print("Chat connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -73,6 +103,7 @@ async def main(): coze_api_token = get_coze_api_token() coze_api_base = get_coze_api_base() bot_id = os.getenv("COZE_BOT_ID") + text = os.getenv("COZE_TEXT") or "你今天好吗? 今天天气不错呀" # Initialize Coze client coze = AsyncCoze( @@ -80,7 +111,7 @@ async def main(): base_url=coze_api_base, ) # Initialize Audio - speech_stream = wrap_coze_speech_to_iterator(coze, "你今天好吗? 今天天气不错呀") + speech_stream = wrap_coze_speech_to_iterator(coze, text) chat = coze.websockets.chat.create( bot_id=bot_id, @@ -91,10 +122,15 @@ async def main(): # Create and connect WebSocket client async with chat() as client: # Read and send audio data - async for data in speech_stream(): - await client.append(data) - await client.commit() - log_info("Audio Committed") + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/poetry.lock b/poetry.lock index bd8b9cf..b47fe41 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1066,38 +1066,6 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] -[[package]] -name = "websocket-client" -version = "1.6.1" -description = "WebSocket client for Python with low level API options" -optional = false -python-versions = ">=3.7" -files = [ - {file = "websocket-client-1.6.1.tar.gz", hash = "sha256:c951af98631d24f8df89ab1019fc365f2227c0892f12fd150e935607c79dd0dd"}, - {file = "websocket_client-1.6.1-py3-none-any.whl", hash = "sha256:f1f9f2ad5291f0225a49efad77abf9e700b6fef553900623060dad6e26503b9d"}, -] - -[package.extras] -docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"] -optional = ["python-socks", "wsaccel"] -test = ["websockets"] - -[[package]] -name = "websocket-client" -version = "1.8.0" -description = "WebSocket client for Python with low level API options" -optional = false -python-versions = ">=3.8" -files = [ - {file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"}, - {file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"}, -] - -[package.extras] -docs = ["Sphinx (>=6.0)", "myst-parser (>=2.0.0)", "sphinx-rtd-theme (>=1.1.0)"] -optional = ["python-socks", "wsaccel"] -test = ["websockets"] - [[package]] name = "websockets" version = "13.1" @@ -1289,4 +1257,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "9a538c47996a060fe0abd622864b4e417052740cf83499a57c07c8af99c5e06c" +content-hash = "4c493526a75fb1ab0bcfd34d8f9e24986f32454ef33a6e99530ec75dbbc8a900" diff --git a/tests/test_audio_translations.py b/tests/test_audio_translations.py index ef29605..b494cb5 100644 --- a/tests/test_audio_translations.py +++ b/tests/test_audio_translations.py @@ -6,7 +6,7 @@ from tests.test_util import logid_key -def mock_create_translation(respx_mock): +def mock_create_transcriptions(respx_mock): logid = random_hex(10) raw_response = httpx.Response( 200, @@ -25,11 +25,11 @@ def mock_create_translation(respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") -class TestAudioTranslation: - def test_sync_translation_create(self, respx_mock): +class TestSyncAudioTranscriptions: + def test_sync_transcriptions_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = coze.audio.transcriptions.create(file=("filename", "content")) assert res @@ -39,11 +39,11 @@ def test_sync_translation_create(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio -class TestAsyncAudioTranslation: - async def test_async_translation_create(self, respx_mock): +class TestAsyncAudioTranscriptions: + async def test_async_transcriptions_create(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = await coze.audio.transcriptions.create(file=("filename", "content")) assert res diff --git a/tests/test_conversations_messages.py b/tests/test_conversations_messages.py index 519579b..2922255 100644 --- a/tests/test_conversations_messages.py +++ b/tests/test_conversations_messages.py @@ -217,7 +217,6 @@ async def test_async_conversations_messages_retrieve(self, respx_mock): async def test_async_conversations_messages_update(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) - mock_msg = mock_update_conversations_messages(respx_mock, Message.build_user_question_text("hi")) message = await coze.conversations.messages.update(conversation_id="conversation id", message_id="message id")