Skip to content

Commit

Permalink
Unify no-speech probability handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
sobomax committed Jan 6, 2025
1 parent 019c7f2 commit 8968c4e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Apps/LiveTranslator/LTSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def text_in(self, result:STTResult):
sdir = 'A->B' if result.direction == 0 else 'B->A'
print(f'STT: {sdir} "{result.text=}" {result.no_speech_prob=}')
nsp = result.no_speech_prob
if nsp > 0.5: return
if nsp > STTRequest.max_ns_prob: return
sinfo = self.fabric.info[result.direction]
text = sinfo.translator(result.text)
speaker_id = sinfo.get_speaker()
Expand Down
11 changes: 8 additions & 3 deletions Cluster/InfernSTTWorker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, device: str, model_name: str = "openai/whisper-large-v3"):
self.device = device
self.infer_and_decode = partial(self.infer_and_decode_ct2 if device != 'xpu' else self.infer_and_decode_torch)

def infer_and_decode_ct2(self, prompts, inputs):
def infer_and_decode_ct2(self, prompts, inputs, max_nsps):
input_features = inputs.input_features
features = ctranslate2.StorageView.from_array(input_features)
try:
Expand All @@ -73,7 +73,7 @@ def infer_and_decode_ct2(self, prompts, inputs):
for r in results)
return decoded_results

def infer_and_decode_torch(self, prompts, inputs):
def infer_and_decode_torch(self, prompts, inputs, max_nsps):
inputs = {k: v.to(self.device) for k, v in inputs.items()}
max_len = max(len(t) for t in prompts)
prompts = torch.stack([
Expand All @@ -87,6 +87,9 @@ def infer_and_decode_torch(self, prompts, inputs):
)
logprobs = forward_outputs.logits[:, 0].log_softmax(-1)
no_speech_probs = logprobs[:, self.no_speech_token_id].exp().tolist()
if all(nsp > max_nsp for nsp, max_nsp in zip(no_speech_probs, max_nsps)):
return (('', nsp, []) for nsp in no_speech_probs)
with torch.no_grad():
gen_outputs = self.model.generate(
**inputs,
decoder_input_ids=prompts,
Expand All @@ -104,10 +107,12 @@ def infer_and_decode_torch(self, prompts, inputs):
def process_batch(self, wis:List[Tuple[STTRequest, List[int]]]):
if self.debug:
print(f'InfernSTTWorker.process_batch: got {len(wis)=}')
assert all(wi[0].chunk.samplerate == self.sample_rate for wi in wis)
audios = [wi[0].chunk.audio for wi in wis]
inputs = self.process_audios(audios, sampling_rate=self.sample_rate)
prompts = self.get_prompt(tuple((wi[0].lang, wi[0].mode, wi[0].timestamps) for wi in wis))
good_results = self.infer_and_decode(prompts, inputs)
max_nsps = [wi[0].max_ns_prob for wi in wis]
good_results = self.infer_and_decode(prompts, inputs, max_nsps)
for (wi, c), (r, nsp, t) in zip(wis, good_results):
# Remove leading and trailing space: "WhitespaceTokenizer adds a space at the beginning?" (copilot)
if len(r) > 0 and r[0] == ' ': r = r[1:]
Expand Down
5 changes: 2 additions & 3 deletions Cluster/STTSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from threading import Lock
from time import monotonic

import torch

from Core.AudioChunk import AudioChunk

class STTRequest():
Expand All @@ -15,7 +13,8 @@ class STTRequest():
text_cb: callable
mode: str = 'transcribe'
timestamps: bool = False
stime:float
stime: float
max_ns_prob: float = 0.5
def __init__(self, chunk:AudioChunk, text_cb:callable, lang:str):
self.stime = monotonic()
self.lang, self.chunk, self.text_cb = lang, chunk, text_cb
Expand Down

0 comments on commit 8968c4e

Please sign in to comment.