Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose client manager args #284

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ If you don't want this, set `--no_single_model`.
- `use_vad`: Whether to use `Voice Activity Detection` on the server.
- `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
- `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
- `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4.
- `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600.

```python
from whisper_live.client import TranscriptionClient
client = TranscriptionClient(
Expand All @@ -87,7 +90,9 @@ client = TranscriptionClient(
model="small",
use_vad=False,
save_output_recording=True, # Only used for microphone input, False by Default
output_recording_filename="./output_recording.wav" # Only used for microphone input
output_recording_filename="./output_recording.wav", # Only used for microphone input
max_clients=4,
max_connection_time=600
)
```
It connects to the server running on localhost at port 9090. Using a multilingual model, language for the transcription will be automatically detected. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language.
Expand Down
2 changes: 1 addition & 1 deletion requirements/server.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ jiwer
evaluate
numpy<2
tiktoken==0.3.3
openai-whisper
openai-whisper==20231117
4 changes: 3 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_on_open(self):
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
"use_vad": True
"use_vad": True,
"max_clients": 4,
"max_connection_time": 600,
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)
Expand Down
26 changes: 12 additions & 14 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from unittest import mock

import numpy as np
import evaluate
import jiwer

from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.assertEqual(server.client_manager.max_clients, 4)
self.assertEqual(server.client_manager.max_connection_time, 600)
self.assertDictEqual(server.client_manager.clients, {})
Expand All @@ -25,6 +26,7 @@ def test_initialization(self):
class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.server.client_manager.start_times = {
'client1': time.time() - 120,
'client2': time.time() - 300
Expand All @@ -49,7 +51,7 @@ def test_connection(self, mock_websocket):
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
Expand All @@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket):
}), np.array([1, 2, 3]).tobytes()]

with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

self.assertNotIn(mock_websocket, self.server.client_manager.clients)

Expand All @@ -82,7 +84,6 @@ def tearDownClass(cls):
cls.server_process.wait()

def setUp(self):
self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()

def check_prediction(self, srt_path):
Expand All @@ -94,11 +95,8 @@ def check_prediction(self, srt_path):
gt_normalized = self.normalizer(gt)

# calculate WER
wer = self.metric.compute(
predictions=[prediction_normalized],
references=[gt_normalized]
)
self.assertLess(wer, 0.05)
wer_score = jiwer.wer(gt_normalized, prediction_normalized)
self.assertLess(wer_score, 0.05)

def test_inference(self):
client = TranscriptionClient(
Expand All @@ -124,26 +122,26 @@ def setUp(self):

@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection_closed_exception(self, mock_websocket):
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed")
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())

with self.assertLogs(level="INFO") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Connection closed by client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_json_decode_exception(self, mock_websocket):
mock_websocket.recv.return_value = "invalid json"

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_unexpected_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
for message in log.output:
print(message)
print()
Expand Down
19 changes: 16 additions & 3 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def __init__(
model="small",
srt_file_path="output.srt",
use_vad=True,
log_transcription=True
log_transcription=True,
max_clients=4,
max_connection_time=600,
):
"""
Initializes a Client instance for audio recording and streaming to a server.
Expand Down Expand Up @@ -59,6 +61,8 @@ def __init__(
self.last_segment = None
self.last_received_segment = None
self.log_transcription = log_transcription
self.max_clients = max_clients
self.max_connection_time = max_connection_time

if translate:
self.task = "translate"
Expand Down Expand Up @@ -199,7 +203,9 @@ def on_open(self, ws):
"language": self.language,
"task": self.task,
"model": self.model,
"use_vad": self.use_vad
"use_vad": self.use_vad,
"max_clients": self.max_clients,
"max_connection_time": self.max_connection_time,
}
)
)
Expand Down Expand Up @@ -681,8 +687,15 @@ def __init__(
output_recording_filename="./output_recording.wav",
output_transcription_path="./output.srt",
log_transcription=True,
max_clients=4,
max_connection_time=600,
):
self.client = Client(host, port, lang, translate, model, srt_file_path=output_transcription_path, use_vad=use_vad, log_transcription=log_transcription)
self.client = Client(
host, port, lang, translate, model, srt_file_path=output_transcription_path,
use_vad=use_vad, log_transcription=log_transcription, max_clients=max_clients,
max_connection_time=max_connection_time
)

if save_output_recording and not output_recording_filename.endswith(".wav"):
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
if not output_transcription_path.endswith(".srt"):
Expand Down
8 changes: 7 additions & 1 deletion whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TranscriptionServer:
RATE = 16000

def __init__(self):
self.client_manager = ClientManager()
self.client_manager = None
self.no_voice_activity_chunks = 0
self.use_vad = True
self.single_model = False
Expand Down Expand Up @@ -224,6 +224,12 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
logging.info("New client connected")
options = websocket.recv()
options = json.loads(options)

if self.client_manager is None:
max_clients = options.get('max_clients', 4)
max_connection_time = options.get('max_connection_time', 600)
self.client_manager = ClientManager(max_clients, max_connection_time)

self.use_vad = options.get('use_vad')
if self.client_manager.is_server_full(websocket, options):
websocket.close()
Expand Down
Loading