Skip to content

Commit

Permalink
Fix unittest to exposed client manager args
Browse files Browse the repository at this point in the history
Signed-off-by: makaveli10 <[email protected]>
  • Loading branch information
makaveli10 committed Oct 28, 2024
1 parent 617fda2 commit 8b87a05
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
4 changes: 3 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_on_open(self):
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
"use_vad": True
"use_vad": True,
"max_clients": 4,
"max_connection_time": 600,
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)
Expand Down
26 changes: 12 additions & 14 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from unittest import mock

import numpy as np
import evaluate
import jiwer

from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.assertEqual(server.client_manager.max_clients, 4)
self.assertEqual(server.client_manager.max_connection_time, 600)
self.assertDictEqual(server.client_manager.clients, {})
Expand All @@ -25,6 +26,7 @@ def test_initialization(self):
class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.server.client_manager.start_times = {
'client1': time.time() - 120,
'client2': time.time() - 300
Expand All @@ -49,7 +51,7 @@ def test_connection(self, mock_websocket):
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
Expand All @@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket):
}), np.array([1, 2, 3]).tobytes()]

with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

self.assertNotIn(mock_websocket, self.server.client_manager.clients)

Expand All @@ -82,7 +84,6 @@ def tearDownClass(cls):
cls.server_process.wait()

def setUp(self):
self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()

def check_prediction(self, srt_path):
Expand All @@ -94,11 +95,8 @@ def check_prediction(self, srt_path):
gt_normalized = self.normalizer(gt)

# calculate WER
wer = self.metric.compute(
predictions=[prediction_normalized],
references=[gt_normalized]
)
self.assertLess(wer, 0.05)
wer_score = jiwer.wer(gt_normalized, prediction_normalized)
self.assertLess(wer_score, 0.05)

def test_inference(self):
client = TranscriptionClient(
Expand All @@ -124,26 +122,26 @@ def setUp(self):

@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection_closed_exception(self, mock_websocket):
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed")
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())

with self.assertLogs(level="INFO") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Connection closed by client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_json_decode_exception(self, mock_websocket):
mock_websocket.recv.return_value = "invalid json"

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_unexpected_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
for message in log.output:
print(message)
print()
Expand Down

0 comments on commit 8b87a05

Please sign in to comment.