From 8b87a0562d6fc8ee4c8d701ce89b340a01a25a40 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Mon, 28 Oct 2024 17:01:06 +0530 Subject: [PATCH] Fix unittest to exposed client manager args Signed-off-by: makaveli10 --- tests/test_client.py | 4 +++- tests/test_server.py | 26 ++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 66189e5c..836b616b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,7 +48,9 @@ def test_on_open(self): "language": self.client.language, "task": self.client.task, "model": self.client.model, - "use_vad": True + "use_vad": True, + "max_clients": 4, + "max_connection_time": 600, }) self.client.on_open(self.mock_ws_app) self.mock_ws_app.send.assert_called_with(expected_message) diff --git a/tests/test_server.py b/tests/test_server.py index f836be70..de787029 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,10 +5,10 @@ from unittest import mock import numpy as np -import evaluate +import jiwer from websockets.exceptions import ConnectionClosed -from whisper_live.server import TranscriptionServer +from whisper_live.server import TranscriptionServer, BackendType, ClientManager from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient from whisper.normalizers import EnglishTextNormalizer @@ -16,6 +16,7 @@ class TestTranscriptionServerInitialization(unittest.TestCase): def test_initialization(self): server = TranscriptionServer() + server.client_manager = ClientManager(max_clients=4, max_connection_time=600) self.assertEqual(server.client_manager.max_clients, 4) self.assertEqual(server.client_manager.max_connection_time, 600) self.assertDictEqual(server.client_manager.clients, {}) @@ -25,6 +26,7 @@ def test_initialization(self): class TestGetWaitTime(unittest.TestCase): def setUp(self): self.server = TranscriptionServer() + self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600) self.server.client_manager.start_times = { 'client1': time.time() - 120, 'client2': time.time() - 300 @@ -49,7 +51,7 @@ def test_connection(self, mock_websocket): 'task': 'transcribe', 'model': 'tiny.en' }) - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) @mock.patch('websockets.WebSocketCommonProtocol') def test_recv_audio_exception_handling(self, mock_websocket): @@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket): }), np.array([1, 2, 3]).tobytes()] with self.assertLogs(level="ERROR"): - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertNotIn(mock_websocket, self.server.client_manager.clients) @@ -82,7 +84,6 @@ def tearDownClass(cls): cls.server_process.wait() def setUp(self): - self.metric = evaluate.load("wer") self.normalizer = EnglishTextNormalizer() def check_prediction(self, srt_path): @@ -94,11 +95,8 @@ def check_prediction(self, srt_path): gt_normalized = self.normalizer(gt) # calculate WER - wer = self.metric.compute( - predictions=[prediction_normalized], - references=[gt_normalized] - ) - self.assertLess(wer, 0.05) + wer_score = jiwer.wer(gt_normalized, prediction_normalized) + self.assertLess(wer_score, 0.05) def test_inference(self): client = TranscriptionClient( @@ -124,10 +122,10 @@ def setUp(self): @mock.patch('websockets.WebSocketCommonProtocol') def test_connection_closed_exception(self, mock_websocket): - mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed") + mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock()) with self.assertLogs(level="INFO") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertTrue(any("Connection closed by client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -135,7 +133,7 @@ def test_json_decode_exception(self, mock_websocket): mock_websocket.recv.return_value = "invalid json" with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -143,7 +141,7 @@ def test_unexpected_exception_handling(self, mock_websocket): mock_websocket.recv.side_effect = RuntimeError("Unexpected error") with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType("faster_whisper")) for message in log.output: print(message) print()