diff --git a/README.md b/README.md index f72ca0b..9aa8209 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,7 @@ conversation = coze.conversations.create( conversation = coze.conversations.retrieve(conversation_id=conversation.id) # append message to conversation -message = coze.conversations.messages.create( +message = coze.conversations.messages.build( # id of conversation conversation_id=conversation.id, content='how are you?', @@ -233,7 +233,7 @@ message = coze.conversations.messages.create( message = coze.conversations.messages.retrieve(conversation_id=conversation.id, message_id=message.id) # update message -coze.conversations.messages.update( +coze.conversations.messages.speech_update( conversation_id=conversation.id, message_id=message.id, content='hey, how are you?', diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 1e615b0..9cd774b 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -1,6 +1,6 @@ from .audio.rooms import CreateRoomResp from .audio.speech import AudioFormat -from .audio.transcriptions import CreateTranslationResp +from .audio.transcriptions import CreateTranscriptionsResp from .audio.voices import Voice from .auth import ( AsyncDeviceOAuthApp, @@ -92,27 +92,27 @@ from .templates import TemplateDuplicateResp, TemplateEntityType from .version import VERSION from .websockets.audio.speech import ( - AsyncWebsocketsAudioSpeechCreateClient, + AsyncWebsocketsAudioSpeechClient, AsyncWebsocketsAudioSpeechEventHandler, InputTextBufferAppendEvent, - InputTextBufferCommitEvent, - InputTextBufferCommittedEvent, + InputTextBufferCompletedEvent, + InputTextBufferCompleteEvent, SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, SpeechUpdateEvent, ) from .websockets.audio.transcriptions import ( - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, InputAudioBufferAppendEvent, - InputAudioBufferCommitEvent, - InputAudioBufferCommittedEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, TranscriptionsMessageCompletedEvent, TranscriptionsMessageUpdateEvent, TranscriptionsUpdateEvent, ) from .websockets.chat import ( - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, ConversationAudioDeltaEvent, ConversationChatCompletedEvent, @@ -143,7 +143,7 @@ "Voice", "AudioFormat", # audio.transcriptions - "CreateTranslationResp", + "CreateTranscriptionsResp", # auth "AsyncDeviceOAuthApp", "AsyncJWTOAuthApp", @@ -217,29 +217,29 @@ "WebsocketsEvent", # websockets.audio.speech "InputTextBufferAppendEvent", - "InputTextBufferCommitEvent", + "InputTextBufferCompleteEvent", "SpeechUpdateEvent", - "InputTextBufferCommittedEvent", + "InputTextBufferCompletedEvent", "SpeechAudioUpdateEvent", "SpeechAudioCompletedEvent", "AsyncWebsocketsAudioSpeechEventHandler", - "AsyncWebsocketsAudioSpeechCreateClient", + "AsyncWebsocketsAudioSpeechClient", # websockets.audio.transcriptions "InputAudioBufferAppendEvent", - "InputAudioBufferCommitEvent", + "InputAudioBufferCompleteEvent", "TranscriptionsUpdateEvent", - "InputAudioBufferCommittedEvent", + "InputAudioBufferCompletedEvent", "TranscriptionsMessageUpdateEvent", "TranscriptionsMessageCompletedEvent", "AsyncWebsocketsAudioTranscriptionsEventHandler", - "AsyncWebsocketsAudioTranscriptionsCreateClient", + "AsyncWebsocketsAudioTranscriptionsClient", # websockets.chat "ConversationChatCreatedEvent", "ConversationMessageDeltaEvent", "ConversationAudioDeltaEvent", "ConversationChatCompletedEvent", "AsyncWebsocketsChatEventHandler", - "AsyncWebsocketsChatCreateClient", + "AsyncWebsocketsChatClient", # workflows.runs "WorkflowRunResult", "WorkflowEventType", diff --git a/cozepy/audio/transcriptions/__init__.py b/cozepy/audio/transcriptions/__init__.py index 7f37181..bac1f82 100644 --- a/cozepy/audio/transcriptions/__init__.py +++ b/cozepy/audio/transcriptions/__init__.py @@ -7,7 +7,7 @@ from cozepy.util import remove_url_trailing_slash -class CreateTranslationResp(CozeModel): +class CreateTranscriptionsResp(CozeModel): # The text of translation text: str @@ -23,18 +23,18 @@ def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" headers: Optional[dict] = kwargs.get("headers") files = {"file": _try_fix_file(file)} return self._requester.request( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) @@ -53,16 +53,16 @@ async def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" files = {"file": _try_fix_file(file)} headers: Optional[dict] = kwargs.get("headers") return await self._requester.arequest( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) diff --git a/cozepy/websockets/__init__.py b/cozepy/websockets/__init__.py index d433773..6353636 100644 --- a/cozepy/websockets/__init__.py +++ b/cozepy/websockets/__init__.py @@ -2,8 +2,31 @@ from cozepy.request import Requester from cozepy.util import http_base_url_to_ws, remove_url_trailing_slash -from .audio import AsyncWebsocketsAudioClient -from .chat import AsyncWebsocketsChatClient +from .audio import AsyncWebsocketsAudioClient, WebsocketsAudioClient +from .chat import AsyncWebsocketsChatBuildClient, WebsocketsChatBuildClient + + +class WebsocketsClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = http_base_url_to_ws(remove_url_trailing_slash(base_url)) + self._auth = auth + self._requester = requester + + @property + def audio(self) -> WebsocketsAudioClient: + return WebsocketsAudioClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def chat(self) -> WebsocketsChatBuildClient: + return WebsocketsChatBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsClient(object): @@ -21,8 +44,8 @@ def audio(self) -> AsyncWebsocketsAudioClient: ) @property - def chat(self) -> AsyncWebsocketsChatClient: - return AsyncWebsocketsChatClient( + def chat(self) -> AsyncWebsocketsChatBuildClient: + return AsyncWebsocketsChatBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/__init__.py b/cozepy/websockets/audio/__init__.py index 3f05e92..a02f3ea 100644 --- a/cozepy/websockets/audio/__init__.py +++ b/cozepy/websockets/audio/__init__.py @@ -1,8 +1,31 @@ from cozepy.auth import Auth from cozepy.request import Requester -from .speech import AsyncWebsocketsAudioSpeechClient -from .transcriptions import AsyncWebsocketsAudioTranscriptionsClient +from .speech import AsyncWebsocketsAudioSpeechBuildClient, WebsocketsAudioSpeechBuildClient +from .transcriptions import AsyncWebsocketsAudioTranscriptionsBuildClient, WebsocketsAudioTranscriptionsBuildClient + + +class WebsocketsAudioClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = base_url + self._auth = auth + self._requester = requester + + @property + def transcriptions(self) -> "WebsocketsAudioTranscriptionsBuildClient": + return WebsocketsAudioTranscriptionsBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def speech(self) -> "WebsocketsAudioSpeechBuildClient": + return WebsocketsAudioSpeechBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsAudioClient(object): @@ -12,16 +35,16 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester @property - def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsClient": - return AsyncWebsocketsAudioTranscriptionsClient( + def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsBuildClient": + return AsyncWebsocketsAudioTranscriptionsBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, ) @property - def speech(self) -> "AsyncWebsocketsAudioSpeechClient": - return AsyncWebsocketsAudioSpeechClient( + def speech(self) -> "AsyncWebsocketsAudioSpeechBuildClient": + return AsyncWebsocketsAudioSpeechBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/speech/__init__.py b/cozepy/websockets/audio/speech/__init__.py index 8190d10..daa4159 100644 --- a/cozepy/websockets/audio/speech/__init__.py +++ b/cozepy/websockets/audio/speech/__init__.py @@ -10,6 +10,8 @@ from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -25,28 +27,28 @@ class Data(BaseModel): # req -class InputTextBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMIT +class InputTextBufferCompleteEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETE # req class SpeechUpdateEvent(WebsocketsEvent): - class OpusConfig(object): + class OpusConfig(BaseModel): bitrate: Optional[int] = None use_cbr: Optional[bool] = None frame_size_ms: Optional[float] = None - class PCMConfig(object): + class PCMConfig(BaseModel): sample_rate: Optional[int] = None - class OutputAudio(object): + class OutputAudio(BaseModel): codec: Optional[str] - pcm_config: Optional["SpeechUpdateEvent.PCMConfig"] - opus_config: Optional["SpeechUpdateEvent.OpusConfig"] - speech_rate: Optional[int] - voice_id: Optional[str] + pcm_config: Optional["SpeechUpdateEvent.PCMConfig"] = None + opus_config: Optional["SpeechUpdateEvent.OpusConfig"] = None + speech_rate: Optional[int] = None + voice_id: Optional[str] = None - class Data: + class Data(BaseModel): output_audio: "SpeechUpdateEvent.OutputAudio" type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE @@ -54,8 +56,8 @@ class Data: # resp -class InputTextBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED +class InputTextBufferCompletedEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED # resp @@ -76,40 +78,34 @@ class SpeechAudioCompletedEvent(WebsocketsEvent): type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_COMPLETED -class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_text_buffer_committed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: InputTextBufferCommittedEvent - ): +class WebsocketsAudioSpeechEventHandler(WebsocketsBaseEventHandler): + def on_input_text_buffer_completed(self, cli: "WebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent): pass - async def on_speech_audio_update( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioUpdateEvent - ): + def on_speech_audio_update(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): pass - async def on_speech_audio_completed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioCompletedEvent - ): + def on_speech_audio_completed(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent): pass -class AsyncWebsocketsAudioSpeechCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioSpeechClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioSpeechEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: on_event.on_input_text_buffer_committed, - WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, - WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, - } + if isinstance(on_event, WebsocketsAudioSpeechEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) super().__init__( base_url=base_url, auth=auth, @@ -120,31 +116,144 @@ def __init__( **kwargs, ) - async def append(self, text: str) -> None: - await self._input_queue.put( - InputTextBufferAppendEvent.model_validate( + def input_text_buffer_append(self, data: InputTextBufferAppendEvent) -> None: + self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + def input_text_buffer_complete(self) -> None: + self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) + + def speech_update(self, event: SpeechUpdateEvent) -> None: + self._input_queue.put(event) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + logid = message.get("logid") or "" + event_type = message.get("type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) + if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: + delta_base64 = data.get("delta") + if delta_base64 is None: + raise ValueError("Missing 'delta' in event data") + return SpeechAudioUpdateEvent.model_validate( { - "data": InputTextBufferAppendEvent.Data.model_validate( + "event_id": event_id, + "logid": logid, + "data": SpeechAudioUpdateEvent.Data.model_validate( { - "delta": text, + "delta": base64.b64decode(delta_base64), } - ) + ), } ) + elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: + return SpeechAudioCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, logid) + return None + + +class WebsocketsAudioSpeechBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, *, on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs + ) -> WebsocketsAudioSpeechClient: + return WebsocketsAudioSpeechClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputTextBufferCommitEvent.model_validate({})) - async def update(self, event: SpeechUpdateEvent) -> None: - await self._input_queue.put(event) +class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + pass + + async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): + pass + + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioSpeechClient(AsyncWebsocketsBaseClient): + class EventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + pass + + async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): + pass + + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + pass + + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union["AsyncWebsocketsAudioSpeechClient.EventHandler", Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioSpeechClient.EventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/speech", + on_event=on_event, + wait_events=[WebsocketsEventType.SPEECH_AUDIO_COMPLETED], + **kwargs, + ) + + async def input_text_buffer_append(self, data: InputTextBufferAppendEvent) -> None: + await self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + async def input_text_buffer_complete(self) -> None: + await self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) + + async def speech_update(self, data: SpeechUpdateEvent.Data) -> None: + await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" + logid = message.get("logid") or "" event_type = message.get("type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: - return InputTextBufferCommittedEvent.model_validate({"event_id": event_id}) + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate({"event_id": event_id, "logid": logid}) if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: delta_base64 = data.get("delta") if delta_base64 is None: @@ -152,6 +261,7 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: return SpeechAudioUpdateEvent.model_validate( { "event_id": event_id, + "logid": logid, "data": SpeechAudioUpdateEvent.Data.model_validate( { "delta": base64.b64decode(delta_base64), @@ -160,22 +270,30 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: } ) elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: - return SpeechAudioCompletedEvent.model_validate({"event_id": event_id}) + return SpeechAudioCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, logid) return None -class AsyncWebsocketsAudioSpeechClient: +class AsyncWebsocketsAudioSpeechBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, *, on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs - ) -> AsyncWebsocketsAudioSpeechCreateClient: - return AsyncWebsocketsAudioSpeechCreateClient( + self, + *, + on_event: Union[AsyncWebsocketsAudioSpeechClient.EventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> AsyncWebsocketsAudioSpeechClient: + return AsyncWebsocketsAudioSpeechClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/transcriptions/__init__.py b/cozepy/websockets/audio/transcriptions/__init__.py index e6dfb4c..cb7a379 100644 --- a/cozepy/websockets/audio/transcriptions/__init__.py +++ b/cozepy/websockets/audio/transcriptions/__init__.py @@ -11,6 +11,8 @@ from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -30,8 +32,8 @@ def serialize_delta(self, delta: bytes, _info): # req -class InputAudioBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMIT +class InputAudioBufferCompleteEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETE # req @@ -51,8 +53,8 @@ class Data(BaseModel): # resp -class InputAudioBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED +class InputAudioBufferCompletedEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED # resp @@ -69,40 +71,40 @@ class TranscriptionsMessageCompletedEvent(WebsocketsEvent): type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED -class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_audio_buffer_committed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: InputAudioBufferCommittedEvent +class WebsocketsAudioTranscriptionsEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): pass - async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + def on_transcriptions_message_update( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): pass - async def on_transcriptions_message_completed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageCompletedEvent + def on_transcriptions_message_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent ): pass -class AsyncWebsocketsAudioTranscriptionsCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioTranscriptionsClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED: on_event.on_input_audio_buffer_committed, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, - } + if isinstance(on_event, WebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) super().__init__( base_url=base_url, auth=auth, @@ -113,31 +115,131 @@ def __init__( **kwargs, ) - async def update(self, data: TranscriptionsUpdateEvent.InputAudio) -> None: - await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + def transcriptions_update(self, data: TranscriptionsUpdateEvent.Data) -> None: + self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("type") or "" + logid = message.get("logid") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE.value: + return TranscriptionsMessageUpdateEvent.model_validate( { - "data": InputAudioBufferAppendEvent.Data.model_validate( + "event_id": event_id, + "logid": logid, + "data": TranscriptionsMessageUpdateEvent.Data.model_validate( { - "delta": delta, + "content": data.get("content") or "", } - ) + ), + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, } ) + else: + log_warning("[v1/audio/transcriptions] unknown event=%s, logid=%s", event_type, logid) + return None + + +class WebsocketsAudioTranscriptionsBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsAudioTranscriptionsClient: + return WebsocketsAudioTranscriptionsClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, + **kwargs, + ) + + +class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_transcriptions_message_update( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + ): + pass + + async def on_transcriptions_message_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioTranscriptionsClient(AsyncWebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/transcriptions", + on_event=on_event, + wait_events=[WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED], + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + async def transcriptions_update(self, data: TranscriptionsUpdateEvent.InputAudio) -> None: + await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" event_type = message.get("type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: - return TranscriptionsMessageCompletedEvent.model_validate( + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( { "event_id": event_id, } @@ -153,14 +255,18 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ), } ) - elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED.value: - pass + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "event_id": event_id, + } + ) else: log_warning("[v1/audio/transcriptions] unknown event=%s", event_type) return None -class AsyncWebsocketsAudioTranscriptionsClient: +class AsyncWebsocketsAudioTranscriptionsBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -171,8 +277,8 @@ def create( *, on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsAudioTranscriptionsCreateClient: - return AsyncWebsocketsAudioTranscriptionsCreateClient( + ) -> AsyncWebsocketsAudioTranscriptionsClient: + return AsyncWebsocketsAudioTranscriptionsClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 1bcb1b9..b4f4687 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -5,15 +5,26 @@ from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash -from cozepy.websockets.audio.transcriptions import InputAudioBufferAppendEvent, InputAudioBufferCommitEvent +from cozepy.websockets.audio.transcriptions import ( + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, +) from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) +# req todo +class ChatUpdateEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATE + + # resp class ConversationChatCreatedEvent(WebsocketsEvent): type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED @@ -26,6 +37,11 @@ class ConversationMessageDeltaEvent(WebsocketsEvent): data: Message +# resp todo +class ConversationChatRequiresActionEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION + + # resp class ConversationAudioDeltaEvent(WebsocketsEvent): type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_DELTA @@ -38,29 +54,177 @@ class ConversationChatCompletedEvent(WebsocketsEvent): data: Chat +class WebsocketsChatEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent): + pass + + def on_conversation_chat_created(self, cli: "WebsocketsChatClient", event: ConversationChatCreatedEvent): + pass + + def on_conversation_message_delta(self, cli: "WebsocketsChatClient", event: ConversationMessageDeltaEvent): + pass + + def on_conversation_chat_requires_action( + self, cli: "WebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass + + def on_conversation_audio_delta(self, cli: "WebsocketsChatClient", event: ConversationAudioDeltaEvent): + pass + + def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: ConversationChatCompletedEvent): + pass + + +class WebsocketsChatClient(WebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, WebsocketsChatEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/chat", + query={ + "bot_id": bot_id, + }, + on_event=on_event, + wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED], + **kwargs, + ) + + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + logid = message.get("logid") or "" + event_type = message.get("type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + return ConversationChatCreatedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": Chat.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: + return ConversationMessageDeltaEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": Message.model_validate(data), # todo + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: + return ConversationAudioDeltaEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: + return ConversationChatCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": Chat.model_validate(data), + } + ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, logid) + return None + + +class WebsocketsChatBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsChatClient: + return WebsocketsChatClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + bot_id=bot_id, + on_event=on_event, + **kwargs, + ) + + class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler): - def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCreatedEvent): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCreatedEvent): pass - def on_conversation_message_delta( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationMessageDeltaEvent + async def on_conversation_message_delta( + self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageDeltaEvent ): pass - # def on_conversation_chat_requires_action(self, cli: 'AsyncWebsocketsChatCreateClient', - # event: ConversationChatRequiresActionEvent): - # pass + async def on_conversation_chat_requires_action( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass - def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationAudioDeltaEvent): + async def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioDeltaEvent): pass - def on_conversation_chat_completed( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCompletedEvent + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent ): pass -class AsyncWebsocketsChatCreateClient(AsyncWebsocketsBaseClient): +class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient): def __init__( self, base_url: str, @@ -71,15 +235,16 @@ def __init__( **kwargs, ): if isinstance(on_event, AsyncWebsocketsChatEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, - WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, - # EventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, - WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, - WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, - } + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) super().__init__( base_url=base_url, auth=auth, @@ -93,63 +258,75 @@ def __init__( **kwargs, ) - # async def update(self, event: TranscriptionsUpdateEventInputAudio) -> None: - # await self._input_queue.put(TranscriptionsUpdateEvent.load(event)) + async def chat_update(self, event: ChatUpdateEvent) -> None: + # TODO + await self._input_queue.put(event) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( - { - "data": InputAudioBufferAppendEvent.Data.model_validate( - { - "delta": delta, - } - ) - } - ) - ) + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" + logid = message.get("logid") or "" event_type = message.get("type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: return ConversationChatCreatedEvent.model_validate( { - "data": Chat.model_validate(data), "event_id": event_id, + "logid": logid, + "data": Chat.model_validate(data), } ) if event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: return ConversationMessageDeltaEvent.model_validate( { + "event_id": event_id, + "logid": logid, "data": Message.model_validate(data), + } + ) + if event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + # TODO "event_id": event_id, + "logid": logid, + "data": Message.model_validate(data), } ) elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: return ConversationAudioDeltaEvent.model_validate( { - "data": Message.model_validate(data), "event_id": event_id, + "logid": logid, + "data": Message.model_validate(data), } ) elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: return ConversationChatCompletedEvent.model_validate( { - "data": Chat.model_validate(data), "event_id": event_id, + "logid": logid, + "data": Chat.model_validate(data), } ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, logid) return None -class AsyncWebsocketsChatClient(object): +class AsyncWebsocketsChatBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -161,8 +338,8 @@ def create( bot_id: str, on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsChatCreateClient: - return AsyncWebsocketsChatCreateClient( + ) -> AsyncWebsocketsChatClient: + return AsyncWebsocketsChatClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 9e787a6..8b0ad9b 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -1,17 +1,20 @@ import abc import asyncio import json +import queue +import threading +import traceback from abc import ABC -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, contextmanager from enum import Enum from typing import Callable, Dict, List, Optional -import websockets +import websockets.asyncio.client +import websockets.sync.client from websockets import InvalidStatus -from websockets.asyncio.connection import Connection from cozepy import Auth, CozeAPIError -from cozepy.log import log_debug, log_info +from cozepy.log import log_debug, log_error, log_info from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -20,49 +23,276 @@ class WebsocketsEventType(str, Enum): # common - ERROR = "error" - CLOSED = "closed" + CLIENT_ERROR = "client_error" # sdk error + CLOSED = "closed" # connection closed + + # error + ERROR = "error" # received error event # v1/audio/speech # req - INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" - INPUT_TEXT_BUFFER_COMMIT = "input_text_buffer.commit" - SPEECH_UPDATE = "speech.update" + INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" # send text to server + INPUT_TEXT_BUFFER_COMPLETE = ( + "input_text_buffer.complete" # no text to send, after audio all received, can close connection + ) + SPEECH_UPDATE = "speech.update" # send speech config to server # resp # v1/audio/speech - INPUT_TEXT_BUFFER_COMMITTED = "input_text_buffer.committed" # ignored - SPEECH_AUDIO_UPDATE = "speech.audio.update" - SPEECH_AUDIO_COMPLETED = "speech.audio.completed" + INPUT_TEXT_BUFFER_COMPLETED = "input_text_buffer.completed" # received `input_text_buffer.complete` event + SPEECH_AUDIO_UPDATE = "speech.audio.update" # received `speech.update` event + SPEECH_AUDIO_COMPLETED = "speech.audio.completed" # all audio received, can close connection # v1/audio/transcriptions # req - INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - TRANSCRIPTIONS_UPDATE = "transcriptions.update" + INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + INPUT_AUDIO_BUFFER_COMPLETE = ( + "input_audio_buffer.complete" # no audio to send, after text all received, can close connection + ) + TRANSCRIPTIONS_UPDATE = "transcriptions.update" # send transcriptions config to server # resp - INPUT_AUDIO_BUFFER_COMMITTED = "input_audio_buffer.committed" # ignored - TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" - TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" + INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" # received `transcriptions.update` event + TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" # all audio received, can close connection # v1/chat # req - # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - # INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - CHAT_UPDATE = "chat.update" + # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + # INPUT_AUDIO_BUFFER_COMPLETE = "input_audio_buffer.complete" # no audio send, start chat + CHAT_UPDATE = "chat.update" # send chat config to server # resp - CONVERSATION_CHAT_CREATED = "conversation.chat.created" - CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" - CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" - CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" - CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" + # INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started + CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" # get agent text message update + CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" # need plugin submit + CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update + CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" # all message received, can close connection class WebsocketsEvent(CozeModel, ABC): - event_id: Optional[str] = None type: WebsocketsEventType + event_id: Optional[str] = None + logid: Optional[str] = None + + +class WebsocketsErrorEvent(WebsocketsEvent): + type: WebsocketsEventType = WebsocketsEventType.ERROR + data: CozeAPIError + + +class WebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + path: str, + query: Optional[Dict[str, str]] = None, + on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, + wait_events: Optional[List[WebsocketsEventType]] = None, + **kwargs, + ): + self._state = self.State.INITIALIZED + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + self._path = path + self._ws_url = self._base_url + "/" + path + if query: + self._ws_url += "?" + "&".join([f"{k}={v}" for k, v in query.items()]) + self._on_event = on_event.copy() if on_event else {} + self._headers = kwargs.get("headers") + self._wait_events = wait_events.copy() if wait_events else [] + + self._input_queue: queue.Queue[Optional[WebsocketsEvent]] = queue.Queue() + self._ws: Optional[websockets.sync.client.ClientConnection] = None + self._send_thread: Optional[threading.Thread] = None + self._receive_thread: Optional[threading.Thread] = None + self._completed_events = set() + self._completed_event = threading.Event() + + @contextmanager + def __call__(self): + try: + self.connect() + yield self + finally: + self.close() + + def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING + headers = { + "Authorization": f"Bearer {self._auth.token}", + "X-Coze-Client-User-Agent": coze_client_user_agent(), + **(self._headers or {}), + } + try: + self._ws = websockets.sync.client.connect( + self._ws_url, + user_agent_header=user_agent(), + additional_headers=headers, + ) + self._state = self.State.CONNECTED + log_info("[%s] connected to websocket", self._path) + + self._send_thread = threading.Thread(target=self._send_loop) + self._receive_thread = threading.Thread(target=self._receive_loop) + self._send_thread.start() + self._receive_thread.start() + except InvalidStatus as e: + raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e + + def wait(self, events: Optional[List[WebsocketsEventType]] = None, wait_all=True) -> None: + if events is None: + events = self._wait_events + self._wait_completed(events, wait_all=wait_all) + + def on(self, event_type: WebsocketsEventType, handler: Callable): + self._on_event[event_type] = handler + + def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING + self._close() + self._state = self.State.CLOSED + + def _send_loop(self) -> None: + try: + while True: + if event := self._input_queue.get(): + self._send_event(event) + self._input_queue.task_done() + except Exception as e: + self._handle_error(e) + + def _receive_loop(self) -> None: + try: + while True: + if not self._ws: + log_debug("[%s] empty websocket conn, close", self._path) + break + + data = self._ws.recv() + message = json.loads(data) + event_type = message.get("type") + log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) + + if handler := self._on_event.get(event_type): + if event := self._load_all_event(message): + handler(self, event) + self._completed_events.add(event_type) + self._completed_event.set() + except Exception as e: + self._handle_error(e) + + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("type") or "" + logid = message.get("logid") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": CozeAPIError(code, msg, logid), + } + ) + return self._load_event(message) + + @abc.abstractmethod + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... + + def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + while True: + if wait_all: + # 所有事件都需要完成 + if self._completed_events == set(events): + break + else: + # 任意一个事件完成即可 + if any(event in self._completed_events for event in events): + break + self._completed_event.wait() + self._completed_event.clear() + + def _handle_error(self, error: Exception) -> None: + if handler := self._on_event.get(WebsocketsEventType.ERROR): + handler(self, error) + else: + raise error + + def _close(self) -> None: + log_info("[%s] connect closed", self._path) + if self._send_thread: + self._send_thread.join() + if self._receive_thread: + self._receive_thread.join() + + if self._ws: + self._ws.close() + self._ws = None + + while not self._input_queue.empty(): + self._input_queue.get() + + if handler := self._on_event.get(WebsocketsEventType.CLOSED): + handler(self) + + def _send_event(self, event: WebsocketsEvent) -> None: + log_debug("[%s] send event, type=%s", self._path, event.type.value) + if self._ws: + self._ws.send(event.model_dump_json()) + + +class WebsocketsBaseEventHandler(object): + def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_closed(self, cli: "WebsocketsBaseClient"): + pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res class AsyncWebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + def __init__( self, base_url: str, @@ -74,6 +304,7 @@ def __init__( wait_events: Optional[List[WebsocketsEventType]] = None, **kwargs, ): + self._state = self.State.INITIALIZED self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester @@ -86,7 +317,7 @@ def __init__( self._wait_events = wait_events.copy() if wait_events else [] self._input_queue: asyncio.Queue[Optional[WebsocketsEvent]] = asyncio.Queue() - self._ws: Optional[Connection] = None + self._ws: Optional[websockets.asyncio.client.ClientConnection] = None self._send_task: Optional[asyncio.Task] = None self._receive_task: Optional[asyncio.Task] = None @@ -99,21 +330,25 @@ async def __call__(self): await self.close() async def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING headers = { "Authorization": f"Bearer {self._auth.token}", "X-Coze-Client-User-Agent": coze_client_user_agent(), **(self._headers or {}), } try: - self._ws = await websockets.connect( + self._ws = await websockets.asyncio.client.connect( self._ws_url, user_agent_header=user_agent(), additional_headers=headers, ) + self._state = self.State.CONNECTED log_info("[%s] connected to websocket", self._path) - self._receive_task = asyncio.create_task(self._receive_loop()) self._send_task = asyncio.create_task(self._send_loop()) + self._receive_task = asyncio.create_task(self._receive_loop()) except InvalidStatus as e: raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e @@ -126,13 +361,17 @@ def on(self, event_type: WebsocketsEventType, handler: Callable): self._on_event[event_type] = handler async def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING await self._close() + self._state = self.State.CLOSED async def _send_loop(self) -> None: try: while True: - if event := await self._input_queue.get(): - await self._send_event(event) + event = await self._input_queue.get() + await self._send_event(event) self._input_queue.task_done() except Exception as e: await self._handle_error(e) @@ -150,46 +389,69 @@ async def _receive_loop(self) -> None: log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) if handler := self._on_event.get(event_type): - if event := self._load_event(message): + if event := self._load_all_event(message): await handler(self, event) except Exception as e: await self._handle_error(e) + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("type") or "" + logid = message.get("logid") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "event_id": event_id, + "logid": logid, + "data": CozeAPIError(code, msg, logid), + } + ) + return self._load_event(message) + @abc.abstractmethod def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... - async def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + async def _wait_completed(self, wait_events: List[WebsocketsEventType], wait_all: bool) -> None: future: asyncio.Future[None] = asyncio.Future() - original_handlers = {} completed_events = set() - async def _handle_completed(client, event): - event_type = event.type - completed_events.add(event_type) + def _wrap_handler(event_type: WebsocketsEventType, original_handler): + async def wrapped(client, event): + # 先执行原始处理函数 + if original_handler: + await original_handler(client, event) + + # 再检查完成条件 + completed_events.add(event_type) + if wait_all: + # 所有事件都需要完成 + if completed_events == set(wait_events): + if not future.done(): + future.set_result(None) + else: + # 任意一个事件完成即可 + if not future.done(): + future.set_result(None) - if wait_all: - # 所有事件都需要完成 - if completed_events == set(events): - future.set_result(None) - else: - # 任意一个事件完成即可 - future.set_result(None) + return wrapped - # 为每个指定的事件类型临时注册处理函数 - for event_type in events: - original_handlers[event_type] = self._on_event.get(event_type) - self._on_event[event_type] = _handle_completed + # 为每个指定的事件类型包装处理函数 + origin_handlers = {} + for event_type in wait_events: + original_handler = self._on_event.get(event_type) + origin_handlers[event_type] = original_handler + self._on_event[event_type] = _wrap_handler(event_type, original_handler) try: # 等待直到满足完成条件 await future finally: # 恢复所有原来的处理函数 - for event_type, handler in original_handlers.items(): - if handler: - self._on_event[event_type] = handler - else: - self._on_event.pop(event_type, None) + for event_type in wait_events: + if original_handler := origin_handlers.get(event_type): + self._on_event[event_type] = original_handler async def _handle_error(self, error: Exception) -> None: if handler := self._on_event.get(WebsocketsEventType.ERROR): @@ -221,8 +483,22 @@ async def _send_event(self, event: WebsocketsEvent) -> None: class AsyncWebsocketsBaseEventHandler(object): - async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: Exception): - pass + async def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: CozeAPIError): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") async def on_closed(self, cli: "AsyncWebsocketsBaseClient"): pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res diff --git a/examples/websockets_audio_speech.py b/examples/websockets_audio_speech.py index d960f51..2fa9d05 100644 --- a/examples/websockets_audio_speech.py +++ b/examples/websockets_audio_speech.py @@ -5,12 +5,15 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioSpeechCreateClient, - AsyncWebsocketsAudioSpeechEventHandler, + AsyncWebsocketsAudioSpeechClient, + InputTextBufferAppendEvent, + InputTextBufferCompletedEvent, + SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, TokenAuth, setup_logging, ) +from cozepy.log import log_info from cozepy.util import write_pcm_to_wav_file from examples.utils import get_coze_api_base, get_coze_api_token @@ -21,19 +24,31 @@ kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") -class AsyncWebsocketsAudioSpeechEventHandlerSub(AsyncWebsocketsAudioSpeechEventHandler): +# todo review +class AsyncWebsocketsAudioSpeechEventHandlerSub(AsyncWebsocketsAudioSpeechClient.EventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechCreateClient, event: SpeechAudioUpdateEvent): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + log_info("[examples] Input text buffer completed") + + async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechClient, event: SpeechAudioUpdateEvent): self.delta.append(event.data.delta) - async def on_error(self, cli: AsyncWebsocketsAudioSpeechCreateClient, e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: AsyncWebsocketsAudioSpeechClient, e: Exception): + log_info("[examples] Error occurred: %s", e) - async def on_closed(self, cli: AsyncWebsocketsAudioSpeechCreateClient): - print("Speech connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") + self.delta = [] async def main(): @@ -55,8 +70,14 @@ async def main(): text = "你今天好吗? 今天天气不错呀" async with speech() as client: - await client.append(text) - await client.commit() + await client.input_text_buffer_append( + InputTextBufferAppendEvent.Data.model_validate( + { + "delta": text, + } + ) + ) + await client.input_text_buffer_complete() await client.wait() diff --git a/examples/websockets_audio_transcriptions.py b/examples/websockets_audio_transcriptions.py index 3dc47a6..573b2e3 100644 --- a/examples/websockets_audio_transcriptions.py +++ b/examples/websockets_audio_transcriptions.py @@ -5,13 +5,16 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, AudioFormat, + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, TokenAuth, TranscriptionsMessageUpdateEvent, setup_logging, ) +from cozepy.log import log_info from examples.utils import get_coze_api_base, get_coze_api_token coze_log = os.getenv("COZE_LOG") @@ -25,16 +28,21 @@ class AudioTranscriptionsEventHandlerSub(AsyncWebsocketsAudioTranscriptionsEvent Class is not required, you can also use Dict to set callback """ - async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient"): - print("Connection closed") + async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsClient"): + log_info("[examples] Connect closed") - async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsClient", e: Exception): + log_info("[examples] Error occurred: %s", e) async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): - print("Received:", event.data.content) + log_info("[examples] Received: %s", event.data.content) + + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + ): + log_info("[examples] Input audio buffer completed") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -72,9 +80,15 @@ async def main(): # Create and connect WebSocket client async with transcriptions() as client: - async for data in speech_stream(): - await client.append(data) - await client.commit() + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/examples/websockets_chat.py b/examples/websockets_chat.py index b95d93e..e17fbf2 100644 --- a/examples/websockets_chat.py +++ b/examples/websockets_chat.py @@ -5,13 +5,15 @@ from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, AudioFormat, ConversationAudioDeltaEvent, + ConversationChatCompletedEvent, ConversationChatCreatedEvent, ConversationMessageDeltaEvent, + InputAudioBufferAppendEvent, TokenAuth, setup_logging, ) @@ -27,30 +29,32 @@ class AsyncWebsocketsChatEventHandlerSub(AsyncWebsocketsChatEventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_conversation_chat_created( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationChatCreatedEvent - ): - log_info("ChatCreated") + async def on_error(self, cli: AsyncWebsocketsAudioTranscriptionsClient, e: Exception): + import traceback - async def on_conversation_message_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationMessageDeltaEvent - ): + log_info(f"Error occurred: {str(e)}") + log_info(f"Stack trace:\n{traceback.format_exc()}") + + async def on_conversation_chat_created(self, cli: AsyncWebsocketsChatClient, event: ConversationChatCreatedEvent): + log_info("[examples] Chat created, means the AST completed and sent to LLM") + + async def on_conversation_message_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationMessageDeltaEvent): print("Received:", event.data.content) - async def on_conversation_audio_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationAudioDeltaEvent - ): + async def on_conversation_audio_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationAudioDeltaEvent): self.delta.append(event.data.get_audio()) - async def on_error(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient, e: Exception): - log_info(f"Error occurred: {str(e)}") - - async def on_closed(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient): - print("Chat connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -91,10 +95,15 @@ async def main(): # Create and connect WebSocket client async with chat() as client: # Read and send audio data - async for data in speech_stream(): - await client.append(data) - await client.commit() - log_info("Audio Committed") + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/tests/test_audio_translations.py b/tests/test_audio_translations.py index ef29605..b494cb5 100644 --- a/tests/test_audio_translations.py +++ b/tests/test_audio_translations.py @@ -6,7 +6,7 @@ from tests.test_util import logid_key -def mock_create_translation(respx_mock): +def mock_create_transcriptions(respx_mock): logid = random_hex(10) raw_response = httpx.Response( 200, @@ -25,11 +25,11 @@ def mock_create_translation(respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") -class TestAudioTranslation: - def test_sync_translation_create(self, respx_mock): +class TestSyncAudioTranscriptions: + def test_sync_transcriptions_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = coze.audio.transcriptions.create(file=("filename", "content")) assert res @@ -39,11 +39,11 @@ def test_sync_translation_create(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio -class TestAsyncAudioTranslation: - async def test_async_translation_create(self, respx_mock): +class TestAsyncAudioTranscriptions: + async def test_async_transcriptions_create(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = await coze.audio.transcriptions.create(file=("filename", "content")) assert res diff --git a/tests/test_conversations_messages.py b/tests/test_conversations_messages.py index 519579b..9f169d9 100644 --- a/tests/test_conversations_messages.py +++ b/tests/test_conversations_messages.py @@ -139,7 +139,7 @@ def test_sync_conversations_messages_update(self, respx_mock): mock_msg = mock_update_conversations_messages(respx_mock, Message.build_user_question_text("hi")) - message = coze.conversations.messages.update(conversation_id="conversation id", message_id="message id") + message = coze.conversations.messages.speech_update(conversation_id="conversation id", message_id="message id") assert message assert message.response.logid == mock_msg.response.logid assert message.content == mock_msg.content @@ -220,7 +220,9 @@ async def test_async_conversations_messages_update(self, respx_mock): mock_msg = mock_update_conversations_messages(respx_mock, Message.build_user_question_text("hi")) - message = await coze.conversations.messages.update(conversation_id="conversation id", message_id="message id") + message = await coze.conversations.messages.speech_update( + conversation_id="conversation id", message_id="message id" + ) assert message assert message.response.logid == mock_msg.response.logid assert message.content == mock_msg.content