Skip to content

Commit

Permalink
Fixed wrong ASR_MODEL when using whisperx
Browse files Browse the repository at this point in the history
  • Loading branch information
Dennis Döring committed Oct 29, 2023
1 parent d545e5e commit 6256b61
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions app/mbain_whisperx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
from io import StringIO
from threading import Lock
import torch
import whisper
import whisperx
from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
import whisper
from whisperx.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON

model_name= os.getenv("ASR_MODEL", "base")
hf_token= os.getenv("HF_TOKEN", "")
x_models = dict()

if torch.cuda.is_available():
device = "cuda"
model = whisper.load_model(model_name).cuda()
model = whisperx.load_model(model_name, device=device)
if hf_token != "":
diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
else:
device = "cpu"
model = whisper.load_model(model_name)
model = whisperx.load_model(model_name, device=device)
if hf_token != "":
diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
model_lock = Lock()
Expand Down

0 comments on commit 6256b61

Please sign in to comment.