diff --git a/.coderabbit.yaml b/.coderabbit.yaml deleted file mode 100644 index b29c1c3..0000000 --- a/.coderabbit.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json -language: en-US -reviews: - high_level_summary: false # disable auto summary generation diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 1e615b0..eba108f 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -1,6 +1,6 @@ from .audio.rooms import CreateRoomResp from .audio.speech import AudioFormat -from .audio.transcriptions import CreateTranslationResp +from .audio.transcriptions import CreateTranscriptionsResp from .audio.voices import Voice from .auth import ( AsyncDeviceOAuthApp, @@ -92,34 +92,48 @@ from .templates import TemplateDuplicateResp, TemplateEntityType from .version import VERSION from .websockets.audio.speech import ( - AsyncWebsocketsAudioSpeechCreateClient, + AsyncWebsocketsAudioSpeechClient, AsyncWebsocketsAudioSpeechEventHandler, InputTextBufferAppendEvent, - InputTextBufferCommitEvent, - InputTextBufferCommittedEvent, + InputTextBufferCompletedEvent, + InputTextBufferCompleteEvent, SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, SpeechUpdateEvent, + WebsocketsAudioSpeechClient, + WebsocketsAudioSpeechEventHandler, ) from .websockets.audio.transcriptions import ( - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, InputAudioBufferAppendEvent, - InputAudioBufferCommitEvent, - InputAudioBufferCommittedEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, TranscriptionsMessageCompletedEvent, TranscriptionsMessageUpdateEvent, TranscriptionsUpdateEvent, + WebsocketsAudioTranscriptionsClient, + WebsocketsAudioTranscriptionsEventHandler, ) from .websockets.chat import ( - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, + ChatUpdateEvent, ConversationAudioDeltaEvent, ConversationChatCompletedEvent, ConversationChatCreatedEvent, + ConversationChatRequiresActionEvent, + ConversationChatSubmitToolOutputsEvent, ConversationMessageDeltaEvent, + WebsocketsChatClient, + WebsocketsChatEventHandler, ) from .websockets.ws import ( + InputAudio, + OpusConfig, + OutputAudio, + PCMConfig, + WebsocketsErrorEvent, WebsocketsEvent, WebsocketsEventType, ) @@ -143,7 +157,7 @@ "Voice", "AudioFormat", # audio.transcriptions - "CreateTranslationResp", + "CreateTranscriptionsResp", # auth "AsyncDeviceOAuthApp", "AsyncJWTOAuthApp", @@ -212,34 +226,48 @@ "DocumentSourceInfo", "DocumentUpdateRule", "DocumentBase", - # websockets - "WebsocketsEventType", - "WebsocketsEvent", # websockets.audio.speech "InputTextBufferAppendEvent", - "InputTextBufferCommitEvent", + "InputTextBufferCompleteEvent", "SpeechUpdateEvent", - "InputTextBufferCommittedEvent", + "InputTextBufferCompletedEvent", "SpeechAudioUpdateEvent", "SpeechAudioCompletedEvent", + "WebsocketsAudioSpeechEventHandler", + "WebsocketsAudioSpeechClient", "AsyncWebsocketsAudioSpeechEventHandler", - "AsyncWebsocketsAudioSpeechCreateClient", + "AsyncWebsocketsAudioSpeechClient", # websockets.audio.transcriptions "InputAudioBufferAppendEvent", - "InputAudioBufferCommitEvent", + "InputAudioBufferCompleteEvent", "TranscriptionsUpdateEvent", - "InputAudioBufferCommittedEvent", + "InputAudioBufferCompletedEvent", "TranscriptionsMessageUpdateEvent", "TranscriptionsMessageCompletedEvent", + "WebsocketsAudioTranscriptionsEventHandler", + "WebsocketsAudioTranscriptionsClient", "AsyncWebsocketsAudioTranscriptionsEventHandler", - "AsyncWebsocketsAudioTranscriptionsCreateClient", + "AsyncWebsocketsAudioTranscriptionsClient", # websockets.chat + "ChatUpdateEvent", + "ConversationChatSubmitToolOutputsEvent", "ConversationChatCreatedEvent", "ConversationMessageDeltaEvent", + "ConversationChatRequiresActionEvent", "ConversationAudioDeltaEvent", "ConversationChatCompletedEvent", + "WebsocketsChatEventHandler", + "WebsocketsChatClient", "AsyncWebsocketsChatEventHandler", - "AsyncWebsocketsChatCreateClient", + "AsyncWebsocketsChatClient", + # websockets + "WebsocketsEventType", + "WebsocketsEvent", + "WebsocketsErrorEvent", + "InputAudio", + "OpusConfig", + "PCMConfig", + "OutputAudio", # workflows.runs "WorkflowRunResult", "WorkflowEventType", diff --git a/cozepy/audio/transcriptions/__init__.py b/cozepy/audio/transcriptions/__init__.py index 7f37181..bac1f82 100644 --- a/cozepy/audio/transcriptions/__init__.py +++ b/cozepy/audio/transcriptions/__init__.py @@ -7,7 +7,7 @@ from cozepy.util import remove_url_trailing_slash -class CreateTranslationResp(CozeModel): +class CreateTranscriptionsResp(CozeModel): # The text of translation text: str @@ -23,18 +23,18 @@ def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" headers: Optional[dict] = kwargs.get("headers") files = {"file": _try_fix_file(file)} return self._requester.request( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) @@ -53,16 +53,16 @@ async def create( *, file: FileTypes, **kwargs, - ) -> CreateTranslationResp: + ) -> CreateTranscriptionsResp: """ - create translation + create transcriptions :param file: The file to be translated. - :return: create translation result + :return: create transcriptions result """ url = f"{self._base_url}/v1/audio/transcriptions" files = {"file": _try_fix_file(file)} headers: Optional[dict] = kwargs.get("headers") return await self._requester.arequest( - "post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files + "post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files ) diff --git a/cozepy/websockets/__init__.py b/cozepy/websockets/__init__.py index d433773..6353636 100644 --- a/cozepy/websockets/__init__.py +++ b/cozepy/websockets/__init__.py @@ -2,8 +2,31 @@ from cozepy.request import Requester from cozepy.util import http_base_url_to_ws, remove_url_trailing_slash -from .audio import AsyncWebsocketsAudioClient -from .chat import AsyncWebsocketsChatClient +from .audio import AsyncWebsocketsAudioClient, WebsocketsAudioClient +from .chat import AsyncWebsocketsChatBuildClient, WebsocketsChatBuildClient + + +class WebsocketsClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = http_base_url_to_ws(remove_url_trailing_slash(base_url)) + self._auth = auth + self._requester = requester + + @property + def audio(self) -> WebsocketsAudioClient: + return WebsocketsAudioClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def chat(self) -> WebsocketsChatBuildClient: + return WebsocketsChatBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsClient(object): @@ -21,8 +44,8 @@ def audio(self) -> AsyncWebsocketsAudioClient: ) @property - def chat(self) -> AsyncWebsocketsChatClient: - return AsyncWebsocketsChatClient( + def chat(self) -> AsyncWebsocketsChatBuildClient: + return AsyncWebsocketsChatBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/__init__.py b/cozepy/websockets/audio/__init__.py index 3f05e92..a02f3ea 100644 --- a/cozepy/websockets/audio/__init__.py +++ b/cozepy/websockets/audio/__init__.py @@ -1,8 +1,31 @@ from cozepy.auth import Auth from cozepy.request import Requester -from .speech import AsyncWebsocketsAudioSpeechClient -from .transcriptions import AsyncWebsocketsAudioTranscriptionsClient +from .speech import AsyncWebsocketsAudioSpeechBuildClient, WebsocketsAudioSpeechBuildClient +from .transcriptions import AsyncWebsocketsAudioTranscriptionsBuildClient, WebsocketsAudioTranscriptionsBuildClient + + +class WebsocketsAudioClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = base_url + self._auth = auth + self._requester = requester + + @property + def transcriptions(self) -> "WebsocketsAudioTranscriptionsBuildClient": + return WebsocketsAudioTranscriptionsBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) + + @property + def speech(self) -> "WebsocketsAudioSpeechBuildClient": + return WebsocketsAudioSpeechBuildClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + ) class AsyncWebsocketsAudioClient(object): @@ -12,16 +35,16 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester @property - def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsClient": - return AsyncWebsocketsAudioTranscriptionsClient( + def transcriptions(self) -> "AsyncWebsocketsAudioTranscriptionsBuildClient": + return AsyncWebsocketsAudioTranscriptionsBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, ) @property - def speech(self) -> "AsyncWebsocketsAudioSpeechClient": - return AsyncWebsocketsAudioSpeechClient( + def speech(self) -> "AsyncWebsocketsAudioSpeechBuildClient": + return AsyncWebsocketsAudioSpeechBuildClient( base_url=self._base_url, auth=self._auth, requester=self._requester, diff --git a/cozepy/websockets/audio/speech/__init__.py b/cozepy/websockets/audio/speech/__init__.py index 8190d10..426df1f 100644 --- a/cozepy/websockets/audio/speech/__init__.py +++ b/cozepy/websockets/audio/speech/__init__.py @@ -10,6 +10,9 @@ from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + OutputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -20,42 +23,27 @@ class InputTextBufferAppendEvent(WebsocketsEvent): class Data(BaseModel): delta: str - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_APPEND + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_APPEND data: Data # req -class InputTextBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMIT +class InputTextBufferCompleteEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETE # req class SpeechUpdateEvent(WebsocketsEvent): - class OpusConfig(object): - bitrate: Optional[int] = None - use_cbr: Optional[bool] = None - frame_size_ms: Optional[float] = None - - class PCMConfig(object): - sample_rate: Optional[int] = None - - class OutputAudio(object): - codec: Optional[str] - pcm_config: Optional["SpeechUpdateEvent.PCMConfig"] - opus_config: Optional["SpeechUpdateEvent.OpusConfig"] - speech_rate: Optional[int] - voice_id: Optional[str] - - class Data: - output_audio: "SpeechUpdateEvent.OutputAudio" + class Data(BaseModel): + output_audio: Optional[OutputAudio] = None - type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE data: Data # resp -class InputTextBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED +class InputTextBufferCompletedEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED # resp @@ -67,91 +55,185 @@ class Data(BaseModel): def serialize_delta(self, delta: bytes, _info): return base64.b64encode(delta) - type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_UPDATE data: Data # resp class SpeechAudioCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.SPEECH_AUDIO_COMPLETED -class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_text_buffer_committed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: InputTextBufferCommittedEvent - ): +class WebsocketsAudioSpeechEventHandler(WebsocketsBaseEventHandler): + def on_input_text_buffer_completed(self, cli: "WebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent): pass - async def on_speech_audio_update( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioUpdateEvent - ): + def on_speech_audio_update(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): pass - async def on_speech_audio_completed( - self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioCompletedEvent - ): + def on_speech_audio_completed(self, cli: "WebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent): pass -class AsyncWebsocketsAudioSpeechCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioSpeechClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioSpeechEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: on_event.on_input_text_buffer_committed, - WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, - WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, - } + if isinstance(on_event, WebsocketsAudioSpeechEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) super().__init__( base_url=base_url, auth=auth, requester=requester, path="v1/audio/speech", - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.SPEECH_AUDIO_COMPLETED], **kwargs, ) - async def append(self, text: str) -> None: - await self._input_queue.put( - InputTextBufferAppendEvent.model_validate( + def input_text_buffer_append(self, data: InputTextBufferAppendEvent.Data) -> None: + self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + def input_text_buffer_complete(self) -> None: + self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) + + def speech_update(self, event: SpeechUpdateEvent) -> None: + self._input_queue.put(event) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate( { - "data": InputTextBufferAppendEvent.Data.model_validate( + "id": event_id, + "detail": detail, + } + ) + if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: + delta_base64 = data.get("delta") + if delta_base64 is None: + raise ValueError("Missing 'delta' in event data") + return SpeechAudioUpdateEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": SpeechAudioUpdateEvent.Data.model_validate( { - "delta": text, + "delta": base64.b64decode(delta_base64), } - ) + ), + } + ) + elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: + return SpeechAudioCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, } ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) + return None + + +class WebsocketsAudioSpeechBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, *, on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs + ) -> WebsocketsAudioSpeechClient: + return WebsocketsAudioSpeechClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, # type: ignore + **kwargs, + ) + + +class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + pass + + async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioUpdateEvent): + pass + + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioSpeechClient(AsyncWebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioSpeechEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: on_event.on_input_text_buffer_completed, + WebsocketsEventType.SPEECH_AUDIO_UPDATE: on_event.on_speech_audio_update, + WebsocketsEventType.SPEECH_AUDIO_COMPLETED: on_event.on_speech_audio_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/speech", + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.SPEECH_AUDIO_COMPLETED], + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputTextBufferCommitEvent.model_validate({})) + async def input_text_buffer_append(self, data: InputTextBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputTextBufferAppendEvent.model_validate({"data": data})) + + async def input_text_buffer_complete(self) -> None: + await self._input_queue.put(InputTextBufferCompleteEvent.model_validate({})) - async def update(self, event: SpeechUpdateEvent) -> None: - await self._input_queue.put(event) + async def speech_update(self, data: SpeechUpdateEvent.Data) -> None: + await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMMITTED: - return InputTextBufferCommittedEvent.model_validate({"event_id": event_id}) + if event_type == WebsocketsEventType.INPUT_TEXT_BUFFER_COMPLETED: + return InputTextBufferCompletedEvent.model_validate({"id": event_id, "detail": detail}) if event_type == WebsocketsEventType.SPEECH_AUDIO_UPDATE.value: delta_base64 = data.get("delta") if delta_base64 is None: raise ValueError("Missing 'delta' in event data") return SpeechAudioUpdateEvent.model_validate( { - "event_id": event_id, + "id": event_id, + "detail": detail, "data": SpeechAudioUpdateEvent.Data.model_validate( { "delta": base64.b64decode(delta_base64), @@ -160,25 +242,33 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: } ) elif event_type == WebsocketsEventType.SPEECH_AUDIO_COMPLETED.value: - return SpeechAudioCompletedEvent.model_validate({"event_id": event_id}) + return SpeechAudioCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None -class AsyncWebsocketsAudioSpeechClient: +class AsyncWebsocketsAudioSpeechBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, *, on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs - ) -> AsyncWebsocketsAudioSpeechCreateClient: - return AsyncWebsocketsAudioSpeechCreateClient( + self, + *, + on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> AsyncWebsocketsAudioSpeechClient: + return AsyncWebsocketsAudioSpeechClient( base_url=self._base_url, auth=self._auth, requester=self._requester, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/audio/transcriptions/__init__.py b/cozepy/websockets/audio/transcriptions/__init__.py index e6dfb4c..462d77a 100644 --- a/cozepy/websockets/audio/transcriptions/__init__.py +++ b/cozepy/websockets/audio/transcriptions/__init__.py @@ -5,12 +5,14 @@ from cozepy.auth import Auth from cozepy.log import log_warning -from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + InputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) @@ -25,34 +27,27 @@ class Data(BaseModel): def serialize_delta(self, delta: bytes, _info): return base64.b64encode(delta) - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND data: Data # req -class InputAudioBufferCommitEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMIT +class InputAudioBufferCompleteEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETE # req class TranscriptionsUpdateEvent(WebsocketsEvent): - class InputAudio(CozeModel): - format: Optional[str] - codec: Optional[str] - sample_rate: Optional[int] - channel: Optional[int] - bit_depth: Optional[int] - class Data(BaseModel): - input_audio: "TranscriptionsUpdateEvent.InputAudio" + input_audio: Optional[InputAudio] = None - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_UPDATE data: Data # resp -class InputAudioBufferCommittedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED +class InputAudioBufferCompletedEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED # resp @@ -60,92 +55,195 @@ class TranscriptionsMessageUpdateEvent(WebsocketsEvent): class Data(BaseModel): content: str - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE data: Data # resp class TranscriptionsMessageCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED -class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): - async def on_input_audio_buffer_committed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: InputAudioBufferCommittedEvent +class WebsocketsAudioTranscriptionsEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): pass - async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + def on_transcriptions_message_update( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): pass - async def on_transcriptions_message_completed( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageCompletedEvent + def on_transcriptions_message_completed( + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent ): pass -class AsyncWebsocketsAudioTranscriptionsCreateClient(AsyncWebsocketsBaseClient): +class WebsocketsAudioTranscriptionsClient(WebsocketsBaseClient): def __init__( self, base_url: str, auth: Auth, requester: Requester, - on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, ): - if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED: on_event.on_input_audio_buffer_committed, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, - WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, - } + if isinstance(on_event, WebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) super().__init__( base_url=base_url, auth=auth, requester=requester, path="v1/audio/transcriptions", - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED], **kwargs, ) - async def update(self, data: TranscriptionsUpdateEvent.InputAudio) -> None: - await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + def transcriptions_update(self, data: TranscriptionsUpdateEvent.Data) -> None: + self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( { - "data": InputAudioBufferAppendEvent.Data.model_validate( + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE.value: + return TranscriptionsMessageUpdateEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": TranscriptionsMessageUpdateEvent.Data.model_validate( { - "delta": delta, + "content": data.get("content") or "", } - ) + ), + } + ) + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, } ) + else: + log_warning("[v1/audio/transcriptions] unknown event=%s, logid=%s", event_type, detail.logid) + return None + + +class WebsocketsAudioTranscriptionsBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsAudioTranscriptionsClient: + return WebsocketsAudioTranscriptionsClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + on_event=on_event, # type: ignore + **kwargs, ) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + +class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_transcriptions_message_update( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + ): + pass + + async def on_transcriptions_message_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent + ): + pass + + +class AsyncWebsocketsAudioTranscriptionsClient(AsyncWebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE: on_event.on_transcriptions_message_update, + WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED: on_event.on_transcriptions_message_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/audio/transcriptions", + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED], + **kwargs, + ) + + async def transcriptions_update(self, data: InputAudio) -> None: + await self._input_queue.put(TranscriptionsUpdateEvent.model_validate({"data": data})) + + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) data = message.get("data") or {} - if event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: - return TranscriptionsMessageCompletedEvent.model_validate( + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( { - "event_id": event_id, + "id": event_id, + "detail": detail, } ) elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_UPDATE.value: return TranscriptionsMessageUpdateEvent.model_validate( { - "event_id": event_id, + "id": event_id, + "detail": detail, "data": TranscriptionsMessageUpdateEvent.Data.model_validate( { "content": data.get("content") or "", @@ -153,14 +251,19 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ), } ) - elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMMITTED.value: - pass + elif event_type == WebsocketsEventType.TRANSCRIPTIONS_MESSAGE_COMPLETED.value: + return TranscriptionsMessageCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) else: - log_warning("[v1/audio/transcriptions] unknown event=%s", event_type) + log_warning("[v1/audio/transcriptions] unknown event=%s, logid=%s", event_type, detail.logid) return None -class AsyncWebsocketsAudioTranscriptionsClient: +class AsyncWebsocketsAudioTranscriptionsBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -171,11 +274,11 @@ def create( *, on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsAudioTranscriptionsCreateClient: - return AsyncWebsocketsAudioTranscriptionsCreateClient( + ) -> AsyncWebsocketsAudioTranscriptionsClient: + return AsyncWebsocketsAudioTranscriptionsClient( base_url=self._base_url, auth=self._auth, requester=self._requester, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 1bcb1b9..2dc1722 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -1,66 +1,265 @@ -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Union -from cozepy import Chat, Message +from pydantic import BaseModel + +from cozepy import Chat, Message, ToolOutput from cozepy.auth import Auth from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash -from cozepy.websockets.audio.transcriptions import InputAudioBufferAppendEvent, InputAudioBufferCommitEvent +from cozepy.websockets.audio.transcriptions import ( + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, + InputAudioBufferCompleteEvent, +) from cozepy.websockets.ws import ( AsyncWebsocketsBaseClient, AsyncWebsocketsBaseEventHandler, + InputAudio, + OutputAudio, + WebsocketsBaseClient, + WebsocketsBaseEventHandler, WebsocketsEvent, WebsocketsEventType, ) +# req +class ChatUpdateEvent(WebsocketsEvent): + class ChatConfig(BaseModel): + conversation_id: Optional[str] = None + user_id: Optional[str] = None + meta_data: Optional[Dict[str, str]] = None + custom_variables: Optional[Dict[str, str]] = None + extra_params: Optional[Dict[str, str]] = None + auto_save_history: Optional[bool] = None + + class Data(BaseModel): + output_audio: Optional[OutputAudio] = None + input_audio: Optional[InputAudio] = None + chat_config: Optional["ChatUpdateEvent.ChatConfig"] = None + + event_type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATE + data: Data + + +# req +class ConversationChatSubmitToolOutputsEvent(WebsocketsEvent): + class Data(BaseModel): + chat_id: str + tool_outputs: List[ToolOutput] + + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS + data: Data + + # resp class ConversationChatCreatedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED data: Chat # resp class ConversationMessageDeltaEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_MESSAGE_DELTA + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_MESSAGE_DELTA data: Message +# resp +class ConversationChatRequiresActionEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION + data: Chat + + # resp class ConversationAudioDeltaEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_DELTA + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_DELTA data: Message # resp class ConversationChatCompletedEvent(WebsocketsEvent): - type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_COMPLETED + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_COMPLETED data: Chat +class WebsocketsChatEventHandler(WebsocketsBaseEventHandler): + def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent): + pass + + def on_conversation_chat_created(self, cli: "WebsocketsChatClient", event: ConversationChatCreatedEvent): + pass + + def on_conversation_message_delta(self, cli: "WebsocketsChatClient", event: ConversationMessageDeltaEvent): + pass + + def on_conversation_chat_requires_action( + self, cli: "WebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass + + def on_conversation_audio_delta(self, cli: "WebsocketsChatClient", event: ConversationAudioDeltaEvent): + pass + + def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: ConversationChatCompletedEvent): + pass + + +class WebsocketsChatClient(WebsocketsBaseClient): + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ): + if isinstance(on_event, WebsocketsChatEventHandler): + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) + super().__init__( + base_url=base_url, + auth=auth, + requester=requester, + path="v1/chat", + query={ + "bot_id": bot_id, + }, + on_event=on_event, # type: ignore + wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED], + **kwargs, + ) + + def chat_update(self, data: ChatUpdateEvent.Data) -> None: + self._input_queue.put(ChatUpdateEvent.model_validate({"data": data})) + + def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data): + self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) + + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + def input_audio_buffer_complete(self) -> None: + self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) + + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" + data = message.get("data") or {} + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + return ConversationChatCreatedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: + return ConversationMessageDeltaEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: + return ConversationAudioDeltaEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Message.model_validate(data), + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: + return ConversationChatCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), + } + ) + else: + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) + return None + + +class WebsocketsChatBuildClient(object): + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + + def create( + self, + *, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, + ) -> WebsocketsChatClient: + return WebsocketsChatClient( + base_url=self._base_url, + auth=self._auth, + requester=self._requester, + bot_id=bot_id, + on_event=on_event, # type: ignore + **kwargs, + ) + + class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler): - def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCreatedEvent): + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent + ): + pass + + async def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCreatedEvent): pass - def on_conversation_message_delta( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationMessageDeltaEvent + async def on_conversation_message_delta( + self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageDeltaEvent ): pass - # def on_conversation_chat_requires_action(self, cli: 'AsyncWebsocketsChatCreateClient', - # event: ConversationChatRequiresActionEvent): - # pass + async def on_conversation_chat_requires_action( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent + ): + pass - def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationAudioDeltaEvent): + async def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioDeltaEvent): pass - def on_conversation_chat_completed( - self, cli: "AsyncWebsocketsChatCreateClient", event: ConversationChatCompletedEvent + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent ): pass -class AsyncWebsocketsChatCreateClient(AsyncWebsocketsBaseClient): +class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient): def __init__( self, base_url: str, @@ -71,15 +270,16 @@ def __init__( **kwargs, ): if isinstance(on_event, AsyncWebsocketsChatEventHandler): - on_event = { - WebsocketsEventType.ERROR: on_event.on_error, - WebsocketsEventType.CLOSED: on_event.on_closed, - WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, - WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, - # EventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, - WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, - WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, - } + on_event = on_event.to_dict( + { + WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed, + WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, + WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, + WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, + WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + } + ) super().__init__( base_url=base_url, auth=auth, @@ -88,68 +288,81 @@ def __init__( query={ "bot_id": bot_id, }, - on_event=on_event, + on_event=on_event, # type: ignore wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED], **kwargs, ) - # async def update(self, event: TranscriptionsUpdateEventInputAudio) -> None: - # await self._input_queue.put(TranscriptionsUpdateEvent.load(event)) + async def chat_update(self, data: ChatUpdateEvent.Data) -> None: + await self._input_queue.put(ChatUpdateEvent.model_validate({"data": data})) - async def append(self, delta: bytes) -> None: - await self._input_queue.put( - InputAudioBufferAppendEvent.model_validate( - { - "data": InputAudioBufferAppendEvent.Data.model_validate( - { - "delta": delta, - } - ) - } - ) - ) + async def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data) -> None: + await self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) - async def commit(self) -> None: - await self._input_queue.put(InputAudioBufferCommitEvent.model_validate({})) + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: + await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) + + async def input_audio_buffer_complete(self) -> None: + await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({})) def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: event_id = message.get("event_id") or "" - event_type = message.get("type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + event_type = message.get("event_type") or "" data = message.get("data") or {} - if event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: + if event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: + return InputAudioBufferCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CREATED.value: return ConversationChatCreatedEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Chat.model_validate(data), - "event_id": event_id, } ) if event_type == WebsocketsEventType.CONVERSATION_MESSAGE_DELTA.value: return ConversationMessageDeltaEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Message.model_validate(data), - "event_id": event_id, + } + ) + if event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: + return ConversationChatRequiresActionEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": Chat.model_validate(data), } ) elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_DELTA.value: return ConversationAudioDeltaEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Message.model_validate(data), - "event_id": event_id, } ) elif event_type == WebsocketsEventType.CONVERSATION_CHAT_COMPLETED.value: return ConversationChatCompletedEvent.model_validate( { + "id": event_id, + "detail": detail, "data": Chat.model_validate(data), - "event_id": event_id, } ) else: - log_warning("[%s] unknown event=%s", self._path, event_type) + log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None -class AsyncWebsocketsChatClient(object): +class AsyncWebsocketsChatBuildClient(object): def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth @@ -161,12 +374,12 @@ def create( bot_id: str, on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs, - ) -> AsyncWebsocketsChatCreateClient: - return AsyncWebsocketsChatCreateClient( + ) -> AsyncWebsocketsChatClient: + return AsyncWebsocketsChatClient( base_url=self._base_url, auth=self._auth, requester=self._requester, bot_id=bot_id, - on_event=on_event, + on_event=on_event, # type: ignore **kwargs, ) diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 9e787a6..f9d095c 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -1,17 +1,42 @@ import abc import asyncio import json +import queue +import sys +import threading +import traceback from abc import ABC -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, contextmanager from enum import Enum -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Set -import websockets -from websockets import InvalidStatus -from websockets.asyncio.connection import Connection +if sys.version_info >= (3, 8): + # note: >=3.7,<3.8 not support asyncio + from websockets import InvalidStatus + from websockets.asyncio.client import ClientConnection as AsyncWebsocketClientConnection + from websockets.asyncio.client import connect as asyncio_connect +else: + # 警告: 当前Python版本不支持asyncio websockets + # 如果Python版本小于3.8,则不支持异步websockets功能 + import warnings + + warnings.warn("asyncio websockets requires Python >= 3.8") + + class AsyncWebsocketClientConnection(object): + pass + + def asyncio_connect(*args, **kwargs): + pass + + class InvalidStatus(object): + pass + + +import websockets.sync.client +from pydantic import BaseModel from cozepy import Auth, CozeAPIError -from cozepy.log import log_debug, log_info +from cozepy.log import log_debug, log_error, log_info from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -20,49 +45,311 @@ class WebsocketsEventType(str, Enum): # common - ERROR = "error" - CLOSED = "closed" + CLIENT_ERROR = "client_error" # sdk error + CLOSED = "closed" # connection closed + + # error + ERROR = "error" # received error event # v1/audio/speech # req - INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" - INPUT_TEXT_BUFFER_COMMIT = "input_text_buffer.commit" - SPEECH_UPDATE = "speech.update" + INPUT_TEXT_BUFFER_APPEND = "input_text_buffer.append" # send text to server + INPUT_TEXT_BUFFER_COMPLETE = ( + "input_text_buffer.complete" # no text to send, after audio all received, can close connection + ) + SPEECH_UPDATE = "speech.update" # send speech config to server # resp # v1/audio/speech - INPUT_TEXT_BUFFER_COMMITTED = "input_text_buffer.committed" # ignored - SPEECH_AUDIO_UPDATE = "speech.audio.update" - SPEECH_AUDIO_COMPLETED = "speech.audio.completed" + INPUT_TEXT_BUFFER_COMPLETED = "input_text_buffer.completed" # received `input_text_buffer.complete` event + SPEECH_AUDIO_UPDATE = "speech.audio.update" # received `speech.update` event + SPEECH_AUDIO_COMPLETED = "speech.audio.completed" # all audio received, can close connection # v1/audio/transcriptions # req - INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - TRANSCRIPTIONS_UPDATE = "transcriptions.update" + INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + INPUT_AUDIO_BUFFER_COMPLETE = ( + "input_audio_buffer.complete" # no audio to send, after text all received, can close connection + ) + TRANSCRIPTIONS_UPDATE = "transcriptions.update" # send transcriptions config to server # resp - INPUT_AUDIO_BUFFER_COMMITTED = "input_audio_buffer.committed" # ignored - TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" - TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" + INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + TRANSCRIPTIONS_MESSAGE_UPDATE = "transcriptions.message.update" # received `transcriptions.update` event + TRANSCRIPTIONS_MESSAGE_COMPLETED = "transcriptions.message.completed" # all audio received, can close connection # v1/chat # req - # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" - # INPUT_AUDIO_BUFFER_COMMIT = "input_audio_buffer.commit" - CHAT_UPDATE = "chat.update" + # INPUT_AUDIO_BUFFER_APPEND = "input_audio_buffer.append" # send audio to server + # INPUT_AUDIO_BUFFER_COMPLETE = "input_audio_buffer.complete" # no audio send, start chat + CHAT_UPDATE = "chat.update" # send chat config to server + CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server # resp - CONVERSATION_CHAT_CREATED = "conversation.chat.created" - CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" - CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" - CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" - CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" + # INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event + CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started + CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" # get agent text message update + CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" # need plugin submit + CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update + CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" # all message received, can close connection class WebsocketsEvent(CozeModel, ABC): + class Detail(BaseModel): + logid: Optional[str] = None + + event_type: WebsocketsEventType event_id: Optional[str] = None - type: WebsocketsEventType + detail: Optional[Detail] = None + + +class WebsocketsErrorEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.ERROR + data: CozeAPIError + + +class InputAudio(CozeModel): + format: Optional[str] + codec: Optional[str] + sample_rate: Optional[int] + channel: Optional[int] + bit_depth: Optional[int] + + +class OpusConfig(BaseModel): + bitrate: Optional[int] = None + use_cbr: Optional[bool] = None + frame_size_ms: Optional[float] = None + + +class PCMConfig(BaseModel): + sample_rate: Optional[int] = None + + +class OutputAudio(BaseModel): + codec: Optional[str] + pcm_config: Optional[PCMConfig] = None + opus_config: Optional[OpusConfig] = None + speech_rate: Optional[int] = None + voice_id: Optional[str] = None + + +class WebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + + def __init__( + self, + base_url: str, + auth: Auth, + requester: Requester, + path: str, + query: Optional[Dict[str, str]] = None, + on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, + wait_events: Optional[List[WebsocketsEventType]] = None, + **kwargs, + ): + self._state = self.State.INITIALIZED + self._base_url = remove_url_trailing_slash(base_url) + self._auth = auth + self._requester = requester + self._path = path + self._ws_url = self._base_url + "/" + path + if query: + self._ws_url += "?" + "&".join([f"{k}={v}" for k, v in query.items()]) + self._on_event = on_event.copy() if on_event else {} + self._headers = kwargs.get("headers") + self._wait_events = wait_events.copy() if wait_events else [] + + self._input_queue: queue.Queue[Optional[WebsocketsEvent]] = queue.Queue() + self._ws: Optional[websockets.sync.client.ClientConnection] = None + self._send_thread: Optional[threading.Thread] = None + self._receive_thread: Optional[threading.Thread] = None + self._completed_events: Set[WebsocketsEventType] = set() + self._completed_event = threading.Event() + + @contextmanager + def __call__(self): + try: + self.connect() + yield self + finally: + self.close() + + def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING + headers = { + "Authorization": f"Bearer {self._auth.token}", + "X-Coze-Client-User-Agent": coze_client_user_agent(), + **(self._headers or {}), + } + try: + self._ws = websockets.sync.client.connect( + self._ws_url, + user_agent_header=user_agent(), + additional_headers=headers, + ) + self._state = self.State.CONNECTED + log_info("[%s] connected to websocket", self._path) + + self._send_thread = threading.Thread(target=self._send_loop) + self._receive_thread = threading.Thread(target=self._receive_loop) + self._send_thread.start() + self._receive_thread.start() + except InvalidStatus as e: + raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e + + def wait(self, events: Optional[List[WebsocketsEventType]] = None, wait_all=True) -> None: + if events is None: + events = self._wait_events + self._wait_completed(events, wait_all=wait_all) + + def on(self, event_type: WebsocketsEventType, handler: Callable): + self._on_event[event_type] = handler + + def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING + self._close() + self._state = self.State.CLOSED + + def _send_loop(self) -> None: + try: + while True: + event = self._input_queue.get() + self._send_event(event) + self._input_queue.task_done() + except Exception as e: + self._handle_error(e) + + def _receive_loop(self) -> None: + try: + while True: + if not self._ws: + log_debug("[%s] empty websocket conn, close", self._path) + break + + data = self._ws.recv() + message = json.loads(data) + event_type = message.get("event_type") + log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) + + event = self._load_all_event(message) + if event: + handler = self._on_event.get(event_type) + if handler: + handler(self, event) + self._completed_events.add(event_type) + self._completed_event.set() + except Exception as e: + self._handle_error(e) + + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": CozeAPIError(code, msg, WebsocketsEvent.Detail.model_validate(detail).logid), + } + ) + return self._load_event(message) + + @abc.abstractmethod + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... + + def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + while True: + if wait_all: + # 所有事件都需要完成 + if self._completed_events == set(events): + break + else: + # 任意一个事件完成即可 + if any(event in self._completed_events for event in events): + break + self._completed_event.wait() + self._completed_event.clear() + + def _handle_error(self, error: Exception) -> None: + handler = self._on_event.get(WebsocketsEventType.ERROR) + if handler: + handler(self, error) + else: + raise error + + def _close(self) -> None: + log_info("[%s] connect closed", self._path) + if self._send_thread: + self._send_thread.join() + if self._receive_thread: + self._receive_thread.join() + + if self._ws: + self._ws.close() + self._ws = None + + while not self._input_queue.empty(): + self._input_queue.get() + + handler = self._on_event.get(WebsocketsEventType.CLOSED) + if handler: + handler(self) + + def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None: + if not event or not self._ws: + return + log_debug("[%s] send event, type=%s", self._path, event.event_type.value) + self._ws.send(event.model_dump_json()) + + +class WebsocketsBaseEventHandler(object): + def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + def on_closed(self, cli: "WebsocketsBaseClient"): + pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res class AsyncWebsocketsBaseClient(abc.ABC): + class State(str, Enum): + """ + initialized, connecting, connected, closing, closed + """ + + INITIALIZED = "initialized" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + def __init__( self, base_url: str, @@ -74,6 +361,7 @@ def __init__( wait_events: Optional[List[WebsocketsEventType]] = None, **kwargs, ): + self._state = self.State.INITIALIZED self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester @@ -86,7 +374,7 @@ def __init__( self._wait_events = wait_events.copy() if wait_events else [] self._input_queue: asyncio.Queue[Optional[WebsocketsEvent]] = asyncio.Queue() - self._ws: Optional[Connection] = None + self._ws: Optional[AsyncWebsocketClientConnection] = None self._send_task: Optional[asyncio.Task] = None self._receive_task: Optional[asyncio.Task] = None @@ -99,21 +387,25 @@ async def __call__(self): await self.close() async def connect(self): + if self._state != self.State.INITIALIZED: + raise ValueError(f"Cannot connect in {self._state.value} state") + self._state = self.State.CONNECTING headers = { "Authorization": f"Bearer {self._auth.token}", "X-Coze-Client-User-Agent": coze_client_user_agent(), **(self._headers or {}), } try: - self._ws = await websockets.connect( + self._ws = await asyncio_connect( self._ws_url, user_agent_header=user_agent(), additional_headers=headers, ) + self._state = self.State.CONNECTED log_info("[%s] connected to websocket", self._path) - self._receive_task = asyncio.create_task(self._receive_loop()) self._send_task = asyncio.create_task(self._send_loop()) + self._receive_task = asyncio.create_task(self._receive_loop()) except InvalidStatus as e: raise CozeAPIError(None, f"{e}", e.response.headers.get("x-tt-logid")) from e @@ -126,13 +418,17 @@ def on(self, event_type: WebsocketsEventType, handler: Callable): self._on_event[event_type] = handler async def close(self) -> None: + if self._state not in (self.State.CONNECTED, self.State.CONNECTING): + return + self._state = self.State.CLOSING await self._close() + self._state = self.State.CLOSED async def _send_loop(self) -> None: try: while True: - if event := await self._input_queue.get(): - await self._send_event(event) + event = await self._input_queue.get() + await self._send_event(event) self._input_queue.task_done() except Exception as e: await self._handle_error(e) @@ -146,53 +442,79 @@ async def _receive_loop(self) -> None: data = await self._ws.recv() message = json.loads(data) - event_type = message.get("type") + event_type = message.get("event_type") log_debug("[%s] receive event, type=%s, event=%s", self._path, event_type, data) - if handler := self._on_event.get(event_type): - if event := self._load_event(message): - await handler(self, event) + handler = self._on_event.get(event_type) + event = self._load_all_event(message) + if handler and event: + await handler(self, event) except Exception as e: await self._handle_error(e) + def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: + event_id = message.get("event_id") or "" + event_type = message.get("event_type") or "" + detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) + data = message.get("data") or {} + if event_type == WebsocketsEventType.ERROR.value: + code, msg = data.get("code") or 0, data.get("msg") or "" + return WebsocketsErrorEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": CozeAPIError(code, msg, WebsocketsEvent.Detail.model_validate(detail).logid), + } + ) + return self._load_event(message) + @abc.abstractmethod def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... - async def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: + async def _wait_completed(self, wait_events: List[WebsocketsEventType], wait_all: bool) -> None: future: asyncio.Future[None] = asyncio.Future() - original_handlers = {} completed_events = set() - async def _handle_completed(client, event): - event_type = event.type - completed_events.add(event_type) + def _wrap_handler(event_type: WebsocketsEventType, original_handler): + async def wrapped(client, event): + # 先执行原始处理函数 + if original_handler: + await original_handler(client, event) + + # 再检查完成条件 + completed_events.add(event_type) + if wait_all: + # 所有事件都需要完成 + if completed_events == set(wait_events): + if not future.done(): + future.set_result(None) + else: + # 任意一个事件完成即可 + if not future.done(): + future.set_result(None) - if wait_all: - # 所有事件都需要完成 - if completed_events == set(events): - future.set_result(None) - else: - # 任意一个事件完成即可 - future.set_result(None) + return wrapped - # 为每个指定的事件类型临时注册处理函数 - for event_type in events: - original_handlers[event_type] = self._on_event.get(event_type) - self._on_event[event_type] = _handle_completed + # 为每个指定的事件类型包装处理函数 + origin_handlers = {} + for event_type in wait_events: + original_handler = self._on_event.get(event_type) + origin_handlers[event_type] = original_handler + self._on_event[event_type] = _wrap_handler(event_type, original_handler) try: # 等待直到满足完成条件 await future finally: # 恢复所有原来的处理函数 - for event_type, handler in original_handlers.items(): - if handler: - self._on_event[event_type] = handler - else: - self._on_event.pop(event_type, None) + for event_type in wait_events: + original_handler = origin_handlers.get(event_type) + if original_handler: + self._on_event[event_type] = original_handler async def _handle_error(self, error: Exception) -> None: - if handler := self._on_event.get(WebsocketsEventType.ERROR): + handler = self._on_event.get(WebsocketsEventType.ERROR) + if handler: await handler(self, error) else: raise error @@ -211,18 +533,34 @@ async def _close(self) -> None: while not self._input_queue.empty(): await self._input_queue.get() - if handler := self._on_event.get(WebsocketsEventType.CLOSED): + handler = self._on_event.get(WebsocketsEventType.CLOSED) + if handler: await handler(self) - async def _send_event(self, event: WebsocketsEvent) -> None: - log_debug("[%s] send event, type=%s", self._path, event.type.value) - if self._ws: - await self._ws.send(event.model_dump_json(), True) + async def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None: + if not event or not self._ws: + return + log_debug("[%s] send event, type=%s", self._path, event.event_type.value) + await self._ws.send(event.model_dump_json()) class AsyncWebsocketsBaseEventHandler(object): - async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: Exception): - pass + async def on_client_error(self, cli: "WebsocketsBaseClient", e: Exception): + log_error(f"Client Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") + + async def on_error(self, cli: "AsyncWebsocketsBaseClient", e: CozeAPIError): + log_error(f"Error occurred: {str(e)}") + log_error(f"Stack trace:\n{traceback.format_exc()}") async def on_closed(self, cli: "AsyncWebsocketsBaseClient"): pass + + def to_dict(self, origin: Dict[WebsocketsEventType, Callable]): + res = { + WebsocketsEventType.CLIENT_ERROR: self.on_client_error, + WebsocketsEventType.ERROR: self.on_error, + WebsocketsEventType.CLOSED: self.on_closed, + } + res.update(origin) + return res diff --git a/cozepy/workflows/runs/__init__.py b/cozepy/workflows/runs/__init__.py index d9dec93..a774941 100644 --- a/cozepy/workflows/runs/__init__.py +++ b/cozepy/workflows/runs/__init__.py @@ -253,7 +253,7 @@ def resume( url = f"{self._base_url}/v1/workflow/stream_resume" body = { "workflow_id": workflow_id, - "event_id": event_id, + "id": event_id, "resume_data": resume_data, "interrupt_type": interrupt_type, } @@ -385,7 +385,7 @@ async def resume( url = f"{self._base_url}/v1/workflow/stream_resume" body = { "workflow_id": workflow_id, - "event_id": event_id, + "id": event_id, "resume_data": resume_data, "interrupt_type": interrupt_type, } diff --git a/examples/utils/__init__.py b/examples/utils/__init__.py index 9aeb7f5..8753676 100644 --- a/examples/utils/__init__.py +++ b/examples/utils/__init__.py @@ -1,7 +1,8 @@ +import logging import os from typing import Optional -from cozepy import COZE_CN_BASE_URL, DeviceOAuthApp +from cozepy import COZE_CN_BASE_URL, DeviceOAuthApp, setup_logging def get_coze_api_base() -> str: @@ -26,3 +27,9 @@ def get_coze_api_token(workspace_id: Optional[str] = None) -> str: device_code = device_oauth_app.get_device_code(workspace_id) print(f"Please Open: {device_code.verification_url} to get the access token") return device_oauth_app.get_access_token(device_code=device_code.device_code, poll=True).access_token + + +def setup_examples_logger(): + coze_log = os.getenv("COZE_LOG") + if coze_log: + setup_logging(logging.getLevelNamesMapping().get(coze_log.upper(), logging.INFO)) diff --git a/examples/websockets_audio_speech.py b/examples/websockets_audio_speech.py index d960f51..108bb50 100644 --- a/examples/websockets_audio_speech.py +++ b/examples/websockets_audio_speech.py @@ -1,39 +1,51 @@ import asyncio import json -import logging import os from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioSpeechCreateClient, + AsyncWebsocketsAudioSpeechClient, AsyncWebsocketsAudioSpeechEventHandler, + InputTextBufferAppendEvent, + InputTextBufferCompletedEvent, + SpeechAudioCompletedEvent, SpeechAudioUpdateEvent, TokenAuth, - setup_logging, ) +from cozepy.log import log_info from cozepy.util import write_pcm_to_wav_file -from examples.utils import get_coze_api_base, get_coze_api_token +from examples.utils import get_coze_api_base, get_coze_api_token, setup_examples_logger -coze_log = os.getenv("COZE_LOG") -if coze_log: - setup_logging(logging.getLevelNamesMapping()[coze_log.upper()]) +setup_examples_logger() kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") +# todo review class AsyncWebsocketsAudioSpeechEventHandlerSub(AsyncWebsocketsAudioSpeechEventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechCreateClient, event: SpeechAudioUpdateEvent): + async def on_input_text_buffer_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + ): + log_info("[examples] Input text buffer completed") + + async def on_speech_audio_update(self, cli: AsyncWebsocketsAudioSpeechClient, event: SpeechAudioUpdateEvent): self.delta.append(event.data.delta) - async def on_error(self, cli: AsyncWebsocketsAudioSpeechCreateClient, e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: AsyncWebsocketsAudioSpeechClient, e: Exception): + log_info("[examples] Error occurred: %s", e) - async def on_closed(self, cli: AsyncWebsocketsAudioSpeechCreateClient): - print("Speech connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_speech_audio_completed( + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") + self.delta = [] async def main(): @@ -55,8 +67,14 @@ async def main(): text = "你今天好吗? 今天天气不错呀" async with speech() as client: - await client.append(text) - await client.commit() + await client.input_text_buffer_append( + InputTextBufferAppendEvent.Data.model_validate( + { + "delta": text, + } + ) + ) + await client.input_text_buffer_complete() await client.wait() diff --git a/examples/websockets_audio_transcriptions.py b/examples/websockets_audio_transcriptions.py index 3dc47a6..9828dd3 100644 --- a/examples/websockets_audio_transcriptions.py +++ b/examples/websockets_audio_transcriptions.py @@ -1,22 +1,22 @@ import asyncio import json -import logging import os from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, + AsyncWebsocketsAudioTranscriptionsClient, AsyncWebsocketsAudioTranscriptionsEventHandler, AudioFormat, + InputAudioBufferAppendEvent, + InputAudioBufferCompletedEvent, TokenAuth, TranscriptionsMessageUpdateEvent, - setup_logging, ) -from examples.utils import get_coze_api_base, get_coze_api_token +from cozepy.log import log_info +from examples.utils import get_coze_api_base, get_coze_api_token, setup_examples_logger + +setup_examples_logger() -coze_log = os.getenv("COZE_LOG") -if coze_log: - setup_logging(logging.getLevelNamesMapping()[coze_log.upper()]) kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") @@ -25,16 +25,21 @@ class AudioTranscriptionsEventHandlerSub(AsyncWebsocketsAudioTranscriptionsEvent Class is not required, you can also use Dict to set callback """ - async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient"): - print("Connection closed") + async def on_closed(self, cli: "AsyncWebsocketsAudioTranscriptionsClient"): + log_info("[examples] Connect closed") - async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", e: Exception): - print(f"Error occurred: {e}") + async def on_error(self, cli: "AsyncWebsocketsAudioTranscriptionsClient", e: Exception): + log_info("[examples] Error occurred: %s", e) async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsCreateClient", event: TranscriptionsMessageUpdateEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + ): + log_info("[examples] Received: %s", event.data.content) + + async def on_input_audio_buffer_completed( + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): - print("Received:", event.data.content) + log_info("[examples] Input audio buffer completed") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -72,9 +77,15 @@ async def main(): # Create and connect WebSocket client async with transcriptions() as client: - async for data in speech_stream(): - await client.append(data) - await client.commit() + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/examples/websockets_chat.py b/examples/websockets_chat.py index b95d93e..7382e79 100644 --- a/examples/websockets_chat.py +++ b/examples/websockets_chat.py @@ -1,56 +1,82 @@ import asyncio import json -import logging import os from cozepy import ( AsyncCoze, - AsyncWebsocketsAudioTranscriptionsCreateClient, - AsyncWebsocketsChatCreateClient, + AsyncWebsocketsChatClient, AsyncWebsocketsChatEventHandler, AudioFormat, ConversationAudioDeltaEvent, + ConversationChatCompletedEvent, ConversationChatCreatedEvent, + ConversationChatRequiresActionEvent, + ConversationChatSubmitToolOutputsEvent, ConversationMessageDeltaEvent, + InputAudioBufferAppendEvent, TokenAuth, - setup_logging, + ToolOutput, ) from cozepy.log import log_info from cozepy.util import write_pcm_to_wav_file -from examples.utils import get_coze_api_base, get_coze_api_token +from examples.utils import get_coze_api_base, get_coze_api_token, setup_examples_logger -coze_log = os.getenv("COZE_LOG") -if coze_log: - setup_logging(logging.getLevelNamesMapping()[coze_log.upper()]) +setup_examples_logger() kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") class AsyncWebsocketsChatEventHandlerSub(AsyncWebsocketsChatEventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + delta = [] - async def on_conversation_chat_created( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationChatCreatedEvent - ): - log_info("ChatCreated") + async def on_error(self, cli: AsyncWebsocketsChatClient, e: Exception): + import traceback - async def on_conversation_message_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationMessageDeltaEvent - ): + log_info(f"Error occurred: {str(e)}") + log_info(f"Stack trace:\n{traceback.format_exc()}") + + async def on_conversation_chat_created(self, cli: AsyncWebsocketsChatClient, event: ConversationChatCreatedEvent): + log_info("[examples] asr completed, logid=%s", event.detail.logid) + + async def on_conversation_message_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationMessageDeltaEvent): print("Received:", event.data.content) - async def on_conversation_audio_delta( - self, cli: AsyncWebsocketsChatCreateClient, event: ConversationAudioDeltaEvent + async def on_conversation_chat_requires_action( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent ): - self.delta.append(event.data.get_audio()) + def fake_run_local_plugin(): + # this is just fake outputs + return event.data.required_action.submit_tool_outputs.tool_calls[0].function.arguments + + fake_output = fake_run_local_plugin() + await cli.conversation_chat_submit_tool_outputs( + ConversationChatSubmitToolOutputsEvent.Data.model_validate( + { + "chat_id": event.data.id, + "tool_outputs": [ + ToolOutput.model_validate( + { + "tool_call_id": event.data.required_action.submit_tool_outputs.tool_calls[0].id, + "output": fake_output, + } + ) + ], + } + ) + ) - async def on_error(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient, e: Exception): - log_info(f"Error occurred: {str(e)}") + async def on_conversation_audio_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationAudioDeltaEvent): + self.delta.append(event.data.get_audio()) - async def on_closed(self, cli: AsyncWebsocketsAudioTranscriptionsCreateClient): - print("Chat connection closed, saving audio data to output.wav") - audio_data = b"".join(self.delta) - write_pcm_to_wav_file(audio_data, "output.wav") + async def on_conversation_chat_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent + ): + log_info("[examples] Saving audio data to output.wav") + write_pcm_to_wav_file(b"".join(self.delta), "output.wav") def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): @@ -73,6 +99,7 @@ async def main(): coze_api_token = get_coze_api_token() coze_api_base = get_coze_api_base() bot_id = os.getenv("COZE_BOT_ID") + text = os.getenv("COZE_TEXT") or "你今天好吗? 今天天气不错呀" # Initialize Coze client coze = AsyncCoze( @@ -80,7 +107,7 @@ async def main(): base_url=coze_api_base, ) # Initialize Audio - speech_stream = wrap_coze_speech_to_iterator(coze, "你今天好吗? 今天天气不错呀") + speech_stream = wrap_coze_speech_to_iterator(coze, text) chat = coze.websockets.chat.create( bot_id=bot_id, @@ -91,10 +118,15 @@ async def main(): # Create and connect WebSocket client async with chat() as client: # Read and send audio data - async for data in speech_stream(): - await client.append(data) - await client.commit() - log_info("Audio Committed") + async for delta in speech_stream(): + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await client.input_audio_buffer_complete() await client.wait() diff --git a/poetry.lock b/poetry.lock index bd8b9cf..8ae8e12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1067,37 +1067,84 @@ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "s test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] [[package]] -name = "websocket-client" -version = "1.6.1" -description = "WebSocket client for Python with low level API options" +name = "websockets" +version = "11.0.3" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.7" files = [ - {file = "websocket-client-1.6.1.tar.gz", hash = "sha256:c951af98631d24f8df89ab1019fc365f2227c0892f12fd150e935607c79dd0dd"}, - {file = "websocket_client-1.6.1-py3-none-any.whl", hash = "sha256:f1f9f2ad5291f0225a49efad77abf9e700b6fef553900623060dad6e26503b9d"}, -] - -[package.extras] -docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"] -optional = ["python-socks", "wsaccel"] -test = ["websockets"] - -[[package]] -name = "websocket-client" -version = "1.8.0" -description = "WebSocket client for Python with low level API options" -optional = false -python-versions = ">=3.8" -files = [ - {file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"}, - {file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"}, + {file = "websockets-11.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3ccc8a0c387629aec40f2fc9fdcb4b9d5431954f934da3eaf16cdc94f67dbfac"}, + {file = "websockets-11.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d67ac60a307f760c6e65dad586f556dde58e683fab03323221a4e530ead6f74d"}, + {file = "websockets-11.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:84d27a4832cc1a0ee07cdcf2b0629a8a72db73f4cf6de6f0904f6661227f256f"}, + {file = "websockets-11.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffd7dcaf744f25f82190856bc26ed81721508fc5cbf2a330751e135ff1283564"}, + {file = "websockets-11.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7622a89d696fc87af8e8d280d9b421db5133ef5b29d3f7a1ce9f1a7bf7fcfa11"}, + {file = "websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bceab846bac555aff6427d060f2fcfff71042dba6f5fca7dc4f75cac815e57ca"}, + {file = "websockets-11.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:54c6e5b3d3a8936a4ab6870d46bdd6ec500ad62bde9e44462c32d18f1e9a8e54"}, + {file = "websockets-11.0.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:41f696ba95cd92dc047e46b41b26dd24518384749ed0d99bea0a941ca87404c4"}, + {file = "websockets-11.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:86d2a77fd490ae3ff6fae1c6ceaecad063d3cc2320b44377efdde79880e11526"}, + {file = "websockets-11.0.3-cp310-cp310-win32.whl", hash = "sha256:2d903ad4419f5b472de90cd2d40384573b25da71e33519a67797de17ef849b69"}, + {file = "websockets-11.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:1d2256283fa4b7f4c7d7d3e84dc2ece74d341bce57d5b9bf385df109c2a1a82f"}, + {file = "websockets-11.0.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e848f46a58b9fcf3d06061d17be388caf70ea5b8cc3466251963c8345e13f7eb"}, + {file = "websockets-11.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aa5003845cdd21ac0dc6c9bf661c5beddd01116f6eb9eb3c8e272353d45b3288"}, + {file = "websockets-11.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b58cbf0697721120866820b89f93659abc31c1e876bf20d0b3d03cef14faf84d"}, + {file = "websockets-11.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:660e2d9068d2bedc0912af508f30bbeb505bbbf9774d98def45f68278cea20d3"}, + {file = "websockets-11.0.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1f0524f203e3bd35149f12157438f406eff2e4fb30f71221c8a5eceb3617b6b"}, + {file = "websockets-11.0.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:def07915168ac8f7853812cc593c71185a16216e9e4fa886358a17ed0fd9fcf6"}, + {file = "websockets-11.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b30c6590146e53149f04e85a6e4fcae068df4289e31e4aee1fdf56a0dead8f97"}, + {file = "websockets-11.0.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:619d9f06372b3a42bc29d0cd0354c9bb9fb39c2cbc1a9c5025b4538738dbffaf"}, + {file = "websockets-11.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:01f5567d9cf6f502d655151645d4e8b72b453413d3819d2b6f1185abc23e82dd"}, + {file = "websockets-11.0.3-cp311-cp311-win32.whl", hash = "sha256:e1459677e5d12be8bbc7584c35b992eea142911a6236a3278b9b5ce3326f282c"}, + {file = "websockets-11.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:e7837cb169eca3b3ae94cc5787c4fed99eef74c0ab9506756eea335e0d6f3ed8"}, + {file = "websockets-11.0.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9f59a3c656fef341a99e3d63189852be7084c0e54b75734cde571182c087b152"}, + {file = "websockets-11.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2529338a6ff0eb0b50c7be33dc3d0e456381157a31eefc561771ee431134a97f"}, + {file = "websockets-11.0.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34fd59a4ac42dff6d4681d8843217137f6bc85ed29722f2f7222bd619d15e95b"}, + {file = "websockets-11.0.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:332d126167ddddec94597c2365537baf9ff62dfcc9db4266f263d455f2f031cb"}, + {file = "websockets-11.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6505c1b31274723ccaf5f515c1824a4ad2f0d191cec942666b3d0f3aa4cb4007"}, + {file = "websockets-11.0.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f467ba0050b7de85016b43f5a22b46383ef004c4f672148a8abf32bc999a87f0"}, + {file = "websockets-11.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9d9acd80072abcc98bd2c86c3c9cd4ac2347b5a5a0cae7ed5c0ee5675f86d9af"}, + {file = "websockets-11.0.3-cp37-cp37m-win32.whl", hash = "sha256:e590228200fcfc7e9109509e4d9125eace2042fd52b595dd22bbc34bb282307f"}, + {file = "websockets-11.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:b16fff62b45eccb9c7abb18e60e7e446998093cdcb50fed33134b9b6878836de"}, + {file = "websockets-11.0.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fb06eea71a00a7af0ae6aefbb932fb8a7df3cb390cc217d51a9ad7343de1b8d0"}, + {file = "websockets-11.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8a34e13a62a59c871064dfd8ffb150867e54291e46d4a7cf11d02c94a5275bae"}, + {file = "websockets-11.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4841ed00f1026dfbced6fca7d963c4e7043aa832648671b5138008dc5a8f6d99"}, + {file = "websockets-11.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a073fc9ab1c8aff37c99f11f1641e16da517770e31a37265d2755282a5d28aa"}, + {file = "websockets-11.0.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:68b977f21ce443d6d378dbd5ca38621755f2063d6fdb3335bda981d552cfff86"}, + {file = "websockets-11.0.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1a99a7a71631f0efe727c10edfba09ea6bee4166a6f9c19aafb6c0b5917d09c"}, + {file = "websockets-11.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bee9fcb41db2a23bed96c6b6ead6489702c12334ea20a297aa095ce6d31370d0"}, + {file = "websockets-11.0.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4b253869ea05a5a073ebfdcb5cb3b0266a57c3764cf6fe114e4cd90f4bfa5f5e"}, + {file = "websockets-11.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1553cb82942b2a74dd9b15a018dce645d4e68674de2ca31ff13ebc2d9f283788"}, + {file = "websockets-11.0.3-cp38-cp38-win32.whl", hash = "sha256:f61bdb1df43dc9c131791fbc2355535f9024b9a04398d3bd0684fc16ab07df74"}, + {file = "websockets-11.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:03aae4edc0b1c68498f41a6772d80ac7c1e33c06c6ffa2ac1c27a07653e79d6f"}, + {file = "websockets-11.0.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:777354ee16f02f643a4c7f2b3eff8027a33c9861edc691a2003531f5da4f6bc8"}, + {file = "websockets-11.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8c82f11964f010053e13daafdc7154ce7385ecc538989a354ccc7067fd7028fd"}, + {file = "websockets-11.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3580dd9c1ad0701169e4d6fc41e878ffe05e6bdcaf3c412f9d559389d0c9e016"}, + {file = "websockets-11.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f1a3f10f836fab6ca6efa97bb952300b20ae56b409414ca85bff2ad241d2a61"}, + {file = "websockets-11.0.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df41b9bc27c2c25b486bae7cf42fccdc52ff181c8c387bfd026624a491c2671b"}, + {file = "websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:279e5de4671e79a9ac877427f4ac4ce93751b8823f276b681d04b2156713b9dd"}, + {file = "websockets-11.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1fdf26fa8a6a592f8f9235285b8affa72748dc12e964a5518c6c5e8f916716f7"}, + {file = "websockets-11.0.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:69269f3a0b472e91125b503d3c0b3566bda26da0a3261c49f0027eb6075086d1"}, + {file = "websockets-11.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:97b52894d948d2f6ea480171a27122d77af14ced35f62e5c892ca2fae9344311"}, + {file = "websockets-11.0.3-cp39-cp39-win32.whl", hash = "sha256:c7f3cb904cce8e1be667c7e6fef4516b98d1a6a0635a58a57528d577ac18a128"}, + {file = "websockets-11.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c792ea4eabc0159535608fc5658a74d1a81020eb35195dd63214dcf07556f67e"}, + {file = "websockets-11.0.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e58f2c36cc52d41f2659e4c0cbf7353e28c8c9e63e30d8c6d3494dc9fdedcf"}, + {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de36fe9c02995c7e6ae6efe2e205816f5f00c22fd1fbf343d4d18c3d5ceac2f5"}, + {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0ac56b661e60edd453585f4bd68eb6a29ae25b5184fd5ba51e97652580458998"}, + {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e052b8467dd07d4943936009f46ae5ce7b908ddcac3fda581656b1b19c083d9b"}, + {file = "websockets-11.0.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:42cc5452a54a8e46a032521d7365da775823e21bfba2895fb7b77633cce031bb"}, + {file = "websockets-11.0.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e6316827e3e79b7b8e7d8e3b08f4e331af91a48e794d5d8b099928b6f0b85f20"}, + {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8531fdcad636d82c517b26a448dcfe62f720e1922b33c81ce695d0edb91eb931"}, + {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c114e8da9b475739dde229fd3bc6b05a6537a88a578358bc8eb29b4030fac9c9"}, + {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e063b1865974611313a3849d43f2c3f5368093691349cf3c7c8f8f75ad7cb280"}, + {file = "websockets-11.0.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:92b2065d642bf8c0a82d59e59053dd2fdde64d4ed44efe4870fa816c1232647b"}, + {file = "websockets-11.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0ee68fe502f9031f19d495dae2c268830df2760c0524cbac5d759921ba8c8e82"}, + {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcacf2c7a6c3a84e720d1bb2b543c675bf6c40e460300b628bab1b1efc7c034c"}, + {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b67c6f5e5a401fc56394f191f00f9b3811fe843ee93f4a70df3c389d1adf857d"}, + {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d5023a4b6a5b183dc838808087033ec5df77580485fc533e7dab2567851b0a4"}, + {file = "websockets-11.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ed058398f55163a79bb9f06a90ef9ccc063b204bb346c4de78efc5d15abfe602"}, + {file = "websockets-11.0.3-py3-none-any.whl", hash = "sha256:6681ba9e7f8f3b19440921e99efbb40fc89f26cd71bf539e45d8c8a25c976dc6"}, + {file = "websockets-11.0.3.tar.gz", hash = "sha256:88fc51d9a26b10fc331be344f1781224a375b78488fc343620184e95a4b27016"}, ] -[package.extras] -docs = ["Sphinx (>=6.0)", "myst-parser (>=2.0.0)", "sphinx-rtd-theme (>=1.1.0)"] -optional = ["python-socks", "wsaccel"] -test = ["websockets"] - [[package]] name = "websockets" version = "13.1" @@ -1289,4 +1336,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "9a538c47996a060fe0abd622864b4e417052740cf83499a57c07c8af99c5e06c" +content-hash = "3e65d17371ab89d36348d4d4201e5ad68a1db0030290a58ca0cedc07a444c764" diff --git a/pyproject.toml b/pyproject.toml index 196d6a3..af6d4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ distro = "^1.9.0" websockets = [ { version = "^14.1.0", python = ">=3.9" }, { version = "^13.1.0", python = ">=3.8,<3.9" }, + { version = "^11.0.3", python = ">=3.7,<3.8" }, ] [tool.poetry.group.dev.dependencies] diff --git a/tests/test_audio_translations.py b/tests/test_audio_translations.py index ef29605..b494cb5 100644 --- a/tests/test_audio_translations.py +++ b/tests/test_audio_translations.py @@ -6,7 +6,7 @@ from tests.test_util import logid_key -def mock_create_translation(respx_mock): +def mock_create_transcriptions(respx_mock): logid = random_hex(10) raw_response = httpx.Response( 200, @@ -25,11 +25,11 @@ def mock_create_translation(respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") -class TestAudioTranslation: - def test_sync_translation_create(self, respx_mock): +class TestSyncAudioTranscriptions: + def test_sync_transcriptions_create(self, respx_mock): coze = Coze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = coze.audio.transcriptions.create(file=("filename", "content")) assert res @@ -39,11 +39,11 @@ def test_sync_translation_create(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio -class TestAsyncAudioTranslation: - async def test_async_translation_create(self, respx_mock): +class TestAsyncAudioTranscriptions: + async def test_async_transcriptions_create(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) - mock_logid = mock_create_translation(respx_mock) + mock_logid = mock_create_transcriptions(respx_mock) res = await coze.audio.transcriptions.create(file=("filename", "content")) assert res diff --git a/tests/test_conversations_messages.py b/tests/test_conversations_messages.py index 519579b..2922255 100644 --- a/tests/test_conversations_messages.py +++ b/tests/test_conversations_messages.py @@ -217,7 +217,6 @@ async def test_async_conversations_messages_retrieve(self, respx_mock): async def test_async_conversations_messages_update(self, respx_mock): coze = AsyncCoze(auth=TokenAuth(token="token")) - mock_msg = mock_update_conversations_messages(respx_mock, Message.build_user_question_text("hi")) message = await coze.conversations.messages.update(conversation_id="conversation id", message_id="message id")