Skip to content

Commit

Permalink
Merge pull request #292 from makaveli10/fix_srt_file_missing_segments
Browse files Browse the repository at this point in the history
Fix srt file missing segments.
  • Loading branch information
zoq authored Nov 5, 2024
2 parents 00f0ff1 + 8d89de2 commit 0e89573
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
8 changes: 4 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def test_on_message(self):
message = json.dumps({
"uid": self.client.uid,
"segments": [
- {"start": 0, "end": 1, "text": "Test transcript"},
- {"start": 1, "end": 2, "text": "Test transcript 2"},
- {"start": 2, "end": 3, "text": "Test transcript 3"}
+ {"start": 0, "end": 1, "text": "Test transcript", "completed": True},
+ {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True},
+ {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True}
]
})
self.client.on_message(self.mock_ws_app, message)

# Assert that the transcript was updated correctly
- self.assertEqual(len(self.client.transcript), 2)
+ self.assertEqual(len(self.client.transcript), 3)
self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")

def test_on_close(self):
Expand Down
6 changes: 3 additions & 3 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def process_segments(self, segments):
for i, seg in enumerate(segments):
if not text or text[-1] != seg["text"]:
text.append(seg["text"])
- if i == len(segments) - 1:
+ if i == len(segments) - 1 and not seg["completed"]:
self.last_segment = seg
- elif (self.server_backend == "faster_whisper" and
+ elif (self.server_backend == "faster_whisper" and seg["completed"] and
(not self.transcript or
float(seg['start']) >= float(self.transcript[-1]['end']))):
self.transcript.append(seg)
Expand Down Expand Up @@ -259,7 +259,7 @@ def write_srt_file(self, output_path="output.srt"):
"""
if self.server_backend == "faster_whisper":
- if (self.last_segment):
+ if (self.last_segment) and self.transcript[-1]["text"] != self.last_segment["text"]:
self.transcript.append(self.last_segment)
utils.create_srt_file(self.transcript, output_path)

Expand Down
13 changes: 8 additions & 5 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ def speech_to_text(self):
logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
time.sleep(0.01)

- def format_segment(self, start, end, text):
+ def format_segment(self, start, end, text, completed=False):
"""
Formats a transcription segment with precise start and end times alongside the transcribed text.
Expand All @@ -1018,7 +1018,8 @@ def format_segment(self, start, end, text):
return {
'start': "{:.3f}".format(start),
'end': "{:.3f}".format(end),
- 'text': text
+ 'text': text,
+ 'completed': completed
}

def update_segments(self, segments, duration):
Expand Down Expand Up @@ -1058,7 +1059,7 @@ def update_segments(self, segments, duration):
if s.no_speech_prob > self.no_speech_thresh:
continue

- self.transcript.append(self.format_segment(start, end, text_))
+ self.transcript.append(self.format_segment(start, end, text_, completed=True))
offset = min(duration, s.end)

# only process the segments if it satisfies the no_speech_thresh
Expand All @@ -1067,7 +1068,8 @@ def update_segments(self, segments, duration):
last_segment = self.format_segment(
self.timestamp_offset + segments[-1].start,
self.timestamp_offset + min(duration, segments[-1].end),
- self.current_out
+ self.current_out,
+ completed=False
)

# if same incomplete segment is seen multiple times then update the offset
Expand All @@ -1083,7 +1085,8 @@ def update_segments(self, segments, duration):
self.transcript.append(self.format_segment(
self.timestamp_offset,
self.timestamp_offset + duration,
- self.current_out
+ self.current_out,
+ completed=True
))
self.current_out = ''
offset = duration
Expand Down

0 comments on commit 0e89573

Please sign in to comment.