From 1322dd3c2748946a4b0da07060083bc976c82e18 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Thu, 10 Oct 2024 08:36:21 -0400 Subject: [PATCH 1/6] Pin openai-whisper version Signed-off-by: makaveli10 --- requirements/server.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/server.txt b/requirements/server.txt index f1f873c..402f967 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -10,4 +10,4 @@ jiwer evaluate numpy<2 tiktoken==0.3.3 -openai-whisper \ No newline at end of file +openai-whisper==20231117 \ No newline at end of file From 0d74790c670e95f461db45374b9bc1cd8c8f8e81 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Thu, 10 Oct 2024 08:37:45 -0400 Subject: [PATCH 2/6] Expose ClientManager arguments to be passed from client Signed-off-by: makaveli10 --- whisper_live/client.py | 19 ++++++++++++++++--- whisper_live/server.py | 8 +++++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index c252607..4cfb63c 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -30,7 +30,9 @@ def __init__( model="small", srt_file_path="output.srt", use_vad=True, - log_transcription=True + log_transcription=True, + max_clients=4, + max_connection_time=600, ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -59,6 +61,8 @@ def __init__( self.last_segment = None self.last_received_segment = None self.log_transcription = log_transcription + self.max_clients = max_clients + self.max_connection_time = max_connection_time if translate: self.task = "translate" @@ -199,7 +203,9 @@ def on_open(self, ws): "language": self.language, "task": self.task, "model": self.model, - "use_vad": self.use_vad + "use_vad": self.use_vad, + "max_clients": self.max_clients, + "max_connection_time": self.max_connection_time, } ) ) @@ -681,8 +687,15 @@ def __init__( output_recording_filename="./output_recording.wav", output_transcription_path="./output.srt", log_transcription=True, + max_clients=4, + max_connection_time=600, ): - self.client = Client(host, port, lang, translate, model, srt_file_path=output_transcription_path, use_vad=use_vad, log_transcription=log_transcription) + self.client = Client( + host, port, lang, translate, model, srt_file_path=output_transcription_path, + use_vad=use_vad, log_transcription=log_transcription, max_clients=max_clients, + max_connection_time=max_connection_time + ) + if save_output_recording and not output_recording_filename.endswith(".wav"): raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}") if not output_transcription_path.endswith(".srt"): diff --git a/whisper_live/server.py b/whisper_live/server.py index e3346d2..575a2cf 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -147,7 +147,7 @@ class TranscriptionServer: RATE = 16000 def __init__(self): - self.client_manager = ClientManager() + self.client_manager = None self.no_voice_activity_chunks = 0 self.use_vad = True self.single_model = False @@ -224,6 +224,12 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path, logging.info("New client connected") options = websocket.recv() options = json.loads(options) + + if self.client_manager is None: + max_clients = options.get('max_clients', 4) + max_connection_time = options.get('max_connection_time', 600) + self.client_manager = ClientManager(max_clients, max_connection_time) + self.use_vad = options.get('use_vad') if self.client_manager.is_server_full(websocket, options): websocket.close() From 617fda2864298e2f28c8422e94bfaab85b1bf43c Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Thu, 10 Oct 2024 08:43:21 -0400 Subject: [PATCH 3/6] Update Readme Signed-off-by: makaveli10 --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2f98b9b..19bb230 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,9 @@ If you don't want this, set `--no_single_model`. - `use_vad`: Whether to use `Voice Activity Detection` on the server. - `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`. - `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`. + - `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4. + - `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600. + ```python from whisper_live.client import TranscriptionClient client = TranscriptionClient( @@ -87,7 +90,9 @@ client = TranscriptionClient( model="small", use_vad=False, save_output_recording=True, # Only used for microphone input, False by Default - output_recording_filename="./output_recording.wav" # Only used for microphone input + output_recording_filename="./output_recording.wav", # Only used for microphone input + max_clients=4, + max_connection_time=600 ) ``` It connects to the server running on localhost at port 9090. Using a multilingual model, language for the transcription will be automatically detected. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language. From 8b87a0562d6fc8ee4c8d701ce89b340a01a25a40 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Mon, 28 Oct 2024 17:01:06 +0530 Subject: [PATCH 4/6] Fix unittest to exposed client manager args Signed-off-by: makaveli10 --- tests/test_client.py | 4 +++- tests/test_server.py | 26 ++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 66189e5..836b616 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,7 +48,9 @@ def test_on_open(self): "language": self.client.language, "task": self.client.task, "model": self.client.model, - "use_vad": True + "use_vad": True, + "max_clients": 4, + "max_connection_time": 600, }) self.client.on_open(self.mock_ws_app) self.mock_ws_app.send.assert_called_with(expected_message) diff --git a/tests/test_server.py b/tests/test_server.py index f836be7..de78702 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,10 +5,10 @@ from unittest import mock import numpy as np -import evaluate +import jiwer from websockets.exceptions import ConnectionClosed -from whisper_live.server import TranscriptionServer +from whisper_live.server import TranscriptionServer, BackendType, ClientManager from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient from whisper.normalizers import EnglishTextNormalizer @@ -16,6 +16,7 @@ class TestTranscriptionServerInitialization(unittest.TestCase): def test_initialization(self): server = TranscriptionServer() + server.client_manager = ClientManager(max_clients=4, max_connection_time=600) self.assertEqual(server.client_manager.max_clients, 4) self.assertEqual(server.client_manager.max_connection_time, 600) self.assertDictEqual(server.client_manager.clients, {}) @@ -25,6 +26,7 @@ def test_initialization(self): class TestGetWaitTime(unittest.TestCase): def setUp(self): self.server = TranscriptionServer() + self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600) self.server.client_manager.start_times = { 'client1': time.time() - 120, 'client2': time.time() - 300 @@ -49,7 +51,7 @@ def test_connection(self, mock_websocket): 'task': 'transcribe', 'model': 'tiny.en' }) - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) @mock.patch('websockets.WebSocketCommonProtocol') def test_recv_audio_exception_handling(self, mock_websocket): @@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket): }), np.array([1, 2, 3]).tobytes()] with self.assertLogs(level="ERROR"): - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertNotIn(mock_websocket, self.server.client_manager.clients) @@ -82,7 +84,6 @@ def tearDownClass(cls): cls.server_process.wait() def setUp(self): - self.metric = evaluate.load("wer") self.normalizer = EnglishTextNormalizer() def check_prediction(self, srt_path): @@ -94,11 +95,8 @@ def check_prediction(self, srt_path): gt_normalized = self.normalizer(gt) # calculate WER - wer = self.metric.compute( - predictions=[prediction_normalized], - references=[gt_normalized] - ) - self.assertLess(wer, 0.05) + wer_score = jiwer.wer(gt_normalized, prediction_normalized) + self.assertLess(wer_score, 0.05) def test_inference(self): client = TranscriptionClient( @@ -124,10 +122,10 @@ def setUp(self): @mock.patch('websockets.WebSocketCommonProtocol') def test_connection_closed_exception(self, mock_websocket): - mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed") + mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock()) with self.assertLogs(level="INFO") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertTrue(any("Connection closed by client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -135,7 +133,7 @@ def test_json_decode_exception(self, mock_websocket): mock_websocket.recv.return_value = "invalid json" with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -143,7 +141,7 @@ def test_unexpected_exception_handling(self, mock_websocket): mock_websocket.recv.side_effect = RuntimeError("Unexpected error") with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) for message in log.output: print(message) print() From 81c57ae40c6b2f63b18c6ad3ee5d448208752296 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Tue, 5 Nov 2024 18:11:32 +0530 Subject: [PATCH 5/6] Send completed bool with each segment Completed bool represents if the segment is completely processed by the server Signed-off-by: makaveli10 --- whisper_live/client.py | 6 +++--- whisper_live/server.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 4cfb63c..15b6306 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -112,9 +112,9 @@ def process_segments(self, segments): for i, seg in enumerate(segments): if not text or text[-1] != seg["text"]: text.append(seg["text"]) - if i == len(segments) - 1: + if i == len(segments) - 1 and not seg["completed"]: self.last_segment = seg - elif (self.server_backend == "faster_whisper" and + elif (self.server_backend == "faster_whisper" and seg["completed"] and (not self.transcript or float(seg['start']) >= float(self.transcript[-1]['end']))): self.transcript.append(seg) @@ -259,7 +259,7 @@ def write_srt_file(self, output_path="output.srt"): """ if self.server_backend == "faster_whisper": - if (self.last_segment): + if (self.last_segment) and self.transcript[-1]["text"] != self.last_segment["text"]: self.transcript.append(self.last_segment) utils.create_srt_file(self.transcript, output_path) diff --git a/whisper_live/server.py b/whisper_live/server.py index 575a2cf..b68df5a 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -1001,7 +1001,7 @@ def speech_to_text(self): logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}") time.sleep(0.01) - def format_segment(self, start, end, text): + def format_segment(self, start, end, text, completed=False): """ Formats a transcription segment with precise start and end times alongside the transcribed text. @@ -1018,7 +1018,8 @@ def format_segment(self, start, end, text): return { 'start': "{:.3f}".format(start), 'end': "{:.3f}".format(end), - 'text': text + 'text': text, + 'completed': completed } def update_segments(self, segments, duration): @@ -1058,7 +1059,7 @@ def update_segments(self, segments, duration): if s.no_speech_prob > self.no_speech_thresh: continue - self.transcript.append(self.format_segment(start, end, text_)) + self.transcript.append(self.format_segment(start, end, text_, completed=True)) offset = min(duration, s.end) # only process the segments if it satisfies the no_speech_thresh @@ -1067,7 +1068,8 @@ def update_segments(self, segments, duration): last_segment = self.format_segment( self.timestamp_offset + segments[-1].start, self.timestamp_offset + min(duration, segments[-1].end), - self.current_out + self.current_out, + completed=False ) # if same incomplete segment is seen multiple times then update the offset @@ -1083,7 +1085,8 @@ def update_segments(self, segments, duration): self.transcript.append(self.format_segment( self.timestamp_offset, self.timestamp_offset + duration, - self.current_out + self.current_out, + completed=True )) self.current_out = '' offset = duration From 8d89de22d817dec6a0eac38f6fdce331c9244de2 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Tue, 5 Nov 2024 18:12:09 +0530 Subject: [PATCH 6/6] Update tests to incorporate the completed boolean in segments Signed-off-by: makaveli10 --- tests/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 836b616..4610ea9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -68,15 +68,15 @@ def test_on_message(self): message = json.dumps({ "uid": self.client.uid, "segments": [ - {"start": 0, "end": 1, "text": "Test transcript"}, - {"start": 1, "end": 2, "text": "Test transcript 2"}, - {"start": 2, "end": 3, "text": "Test transcript 3"} + {"start": 0, "end": 1, "text": "Test transcript", "completed": True}, + {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True}, + {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True} ] }) self.client.on_message(self.mock_ws_app, message) # Assert that the transcript was updated correctly - self.assertEqual(len(self.client.transcript), 2) + self.assertEqual(len(self.client.transcript), 3) self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2") def test_on_close(self):