Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Spudra authored Nov 5, 2024
2 parents 83988ea + 0e89573 commit 5f270a1
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 32 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ If you don't want this, set `--no_single_model`.
- `use_vad`: Whether to use `Voice Activity Detection` on the server.
- `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
- `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
- `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4.
- `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600.

```python
from whisper_live.client import TranscriptionClient
client = TranscriptionClient(
Expand All @@ -87,10 +90,12 @@ client = TranscriptionClient(
model="small",
use_vad=False,
save_output_recording=True, # Only used for microphone input, False by Default
output_recording_filename="./output_recording.wav" # Only used for microphone input
output_recording_filename="./output_recording.wav", # Only used for microphone input
options={
'initial_prompt': None, #To add context replace None with any context for the model like this: 'Jane Doe context'
}
},
max_clients=4,
max_connection_time=600
)
```
It connects to the server running on localhost at port 9090. When using a multilingual model, the language for the transcription will be detected automatically. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language.
Expand Down
2 changes: 1 addition & 1 deletion requirements/server.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ jiwer
evaluate
numpy<2
tiktoken==0.3.3
openai-whisper
openai-whisper==20231117
12 changes: 7 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_on_open(self):
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
"use_vad": True
"use_vad": True,
"max_clients": 4,
"max_connection_time": 600,
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)
Expand All @@ -66,15 +68,15 @@ def test_on_message(self):
message = json.dumps({
"uid": self.client.uid,
"segments": [
{"start": 0, "end": 1, "text": "Test transcript"},
{"start": 1, "end": 2, "text": "Test transcript 2"},
{"start": 2, "end": 3, "text": "Test transcript 3"}
{"start": 0, "end": 1, "text": "Test transcript", "completed": True},
{"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
{"start": 2, "end": 3, "text": "Test transcript 3", "completed": True}
]
})
self.client.on_message(self.mock_ws_app, message)

# Assert that the transcript was updated correctly
self.assertEqual(len(self.client.transcript), 2)
self.assertEqual(len(self.client.transcript), 3)
self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")

def test_on_close(self):
Expand Down
26 changes: 12 additions & 14 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from unittest import mock

import numpy as np
import evaluate
import jiwer

from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.assertEqual(server.client_manager.max_clients, 4)
self.assertEqual(server.client_manager.max_connection_time, 600)
self.assertDictEqual(server.client_manager.clients, {})
Expand All @@ -25,6 +26,7 @@ def test_initialization(self):
class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.server.client_manager.start_times = {
'client1': time.time() - 120,
'client2': time.time() - 300
Expand All @@ -49,7 +51,7 @@ def test_connection(self, mock_websocket):
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
Expand All @@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket):
}), np.array([1, 2, 3]).tobytes()]

with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

self.assertNotIn(mock_websocket, self.server.client_manager.clients)

Expand All @@ -82,7 +84,6 @@ def tearDownClass(cls):
cls.server_process.wait()

def setUp(self):
self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()

def check_prediction(self, srt_path):
Expand All @@ -94,11 +95,8 @@ def check_prediction(self, srt_path):
gt_normalized = self.normalizer(gt)

# calculate WER
wer = self.metric.compute(
predictions=[prediction_normalized],
references=[gt_normalized]
)
self.assertLess(wer, 0.05)
wer_score = jiwer.wer(gt_normalized, prediction_normalized)
self.assertLess(wer_score, 0.05)

def test_inference(self):
client = TranscriptionClient(
Expand All @@ -124,26 +122,26 @@ def setUp(self):

@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection_closed_exception(self, mock_websocket):
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed")
mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())

with self.assertLogs(level="INFO") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Connection closed by client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_json_decode_exception(self, mock_websocket):
mock_websocket.recv.return_value = "invalid json"

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_unexpected_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
for message in log.output:
print(message)
print()
Expand Down
22 changes: 18 additions & 4 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(
use_vad=True,
log_transcription=True,
        options=None,
max_clients=4,
max_connection_time=600,
):
"""
Initializes a Client instance for audio recording and streaming to a server.
Expand Down Expand Up @@ -61,6 +63,9 @@ def __init__(
self.last_received_segment = None
self.log_transcription = log_transcription
self.options = options
self.max_clients = max_clients
self.max_connection_time = max_connection_time


if translate:
self.task = "translate"
Expand Down Expand Up @@ -110,9 +115,9 @@ def process_segments(self, segments):
for i, seg in enumerate(segments):
if not text or text[-1] != seg["text"]:
text.append(seg["text"])
if i == len(segments) - 1:
if i == len(segments) - 1 and not seg["completed"]:
self.last_segment = seg
elif (self.server_backend == "faster_whisper" and
elif (self.server_backend == "faster_whisper" and seg["completed"] and
(not self.transcript or
float(seg['start']) >= float(self.transcript[-1]['end']))):
self.transcript.append(seg)
Expand Down Expand Up @@ -203,6 +208,8 @@ def on_open(self, ws):
"model": self.model,
"use_vad": self.use_vad,
"options": self.options
"max_clients": self.max_clients,
"max_connection_time": self.max_connection_time,
}
)
)
Expand Down Expand Up @@ -256,7 +263,7 @@ def write_srt_file(self, output_path="output.srt"):
"""
if self.server_backend == "faster_whisper":
if (self.last_segment):
if (self.last_segment) and self.transcript[-1]["text"] != self.last_segment["text"]:
self.transcript.append(self.last_segment)
utils.create_srt_file(self.transcript, output_path)

Expand Down Expand Up @@ -685,8 +692,15 @@ def __init__(
output_transcription_path="./output.srt",
log_transcription=True,
options=None,
max_clients=4,
max_connection_time=600,
):
self.client = Client(host, port, lang, translate, model, srt_file_path=output_transcription_path, use_vad=use_vad, log_transcription=log_transcription, options=options)
self.client = Client(
host, port, lang, translate, model, srt_file_path=output_transcription_path,
use_vad=use_vad, log_transcription=log_transcription, options=options, max_clients=max_clients,
max_connection_time=max_connection_time
)

if save_output_recording and not output_recording_filename.endswith(".wav"):
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
if not output_transcription_path.endswith(".srt"):
Expand Down
21 changes: 15 additions & 6 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TranscriptionServer:
RATE = 16000

def __init__(self):
self.client_manager = ClientManager()
self.client_manager = None
self.no_voice_activity_chunks = 0
self.use_vad = True
self.single_model = False
Expand Down Expand Up @@ -224,6 +224,12 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
logging.info("New client connected")
options = websocket.recv()
options = json.loads(options)

if self.client_manager is None:
max_clients = options.get('max_clients', 4)
max_connection_time = options.get('max_connection_time', 600)
self.client_manager = ClientManager(max_clients, max_connection_time)

self.use_vad = options.get('use_vad')
if self.client_manager.is_server_full(websocket, options):
websocket.close()
Expand Down Expand Up @@ -995,7 +1001,7 @@ def speech_to_text(self):
logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
time.sleep(0.01)

def format_segment(self, start, end, text):
def format_segment(self, start, end, text, completed=False):
"""
Formats a transcription segment with precise start and end times alongside the transcribed text.
Expand All @@ -1012,7 +1018,8 @@ def format_segment(self, start, end, text):
return {
'start': "{:.3f}".format(start),
'end': "{:.3f}".format(end),
'text': text
'text': text,
'completed': completed
}

def update_segments(self, segments, duration):
Expand Down Expand Up @@ -1052,7 +1059,7 @@ def update_segments(self, segments, duration):
if s.no_speech_prob > self.no_speech_thresh:
continue

self.transcript.append(self.format_segment(start, end, text_))
self.transcript.append(self.format_segment(start, end, text_, completed=True))
offset = min(duration, s.end)

# only process the segments if it satisfies the no_speech_thresh
Expand All @@ -1061,7 +1068,8 @@ def update_segments(self, segments, duration):
last_segment = self.format_segment(
self.timestamp_offset + segments[-1].start,
self.timestamp_offset + min(duration, segments[-1].end),
self.current_out
self.current_out,
completed=False
)

# if same incomplete segment is seen multiple times then update the offset
Expand All @@ -1077,7 +1085,8 @@ def update_segments(self, segments, duration):
self.transcript.append(self.format_segment(
self.timestamp_offset,
self.timestamp_offset + duration,
self.current_out
self.current_out,
completed=True
))
self.current_out = ''
offset = duration
Expand Down

0 comments on commit 5f270a1

Please sign in to comment.