diff --git a/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py b/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py index 5ef66f58d..9ee7a8947 100644 --- a/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py +++ b/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py @@ -160,9 +160,16 @@ async def input_generator(): except Exception as e: logger.exception(f"an error occurred while streaming inputs: {e}") - # try to connect handler = TranscriptEventHandler(stream.output_stream, self._event_ch) - await asyncio.gather(input_generator(), handler.handle_events()) + tasks = [ + asyncio.create_task(input_generator()), + asyncio.create_task(handler.handle_events()), + ] + # try to connect + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) except Exception as e: logger.exception(f"an error occurred while streaming inputs: {e}") diff --git a/tests/test_stt.py b/tests/test_stt.py index e3b17100d..0d7a3f4ca 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -109,9 +109,9 @@ async def _stream_output(): async for event in stream: if event.type == agents.stt.SpeechEventType.START_OF_SPEECH: - assert ( - recv_end - ), "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + assert recv_end, ( + "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + ) assert not recv_start recv_end = False recv_start = True diff --git a/tests/utils.py b/tests/utils.py index ad20fc43a..dc7933de4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,18 +20,16 @@ def wer(hypothesis: str, reference: str) -> float: - wer_standardize_contiguous = tr.Compose( - [ - tr.ToLowerCase(), - tr.ExpandCommonEnglishContractions(), - tr.RemoveKaldiNonWords(), - tr.RemoveWhiteSpace(replace_by_space=True), - tr.RemoveMultipleSpaces(), - tr.Strip(), - tr.ReduceToSingleSentence(), - tr.ReduceToListOfListOfWords(), - ] - ) + wer_standardize_contiguous = tr.Compose([ + tr.ToLowerCase(), + tr.ExpandCommonEnglishContractions(), + tr.RemoveKaldiNonWords(), + tr.RemoveWhiteSpace(replace_by_space=True), + tr.RemoveMultipleSpaces(), + tr.Strip(), + tr.ReduceToSingleSentence(), + tr.ReduceToListOfListOfWords(), + ]) return tr.wer( reference,