From 4081e345a73895baabe427d10d8e2cd2cee026a3 Mon Sep 17 00:00:00 2001
From: Octavian Mot <64348499+octimot@users.noreply.github.com>
Date: Thu, 4 May 2023 12:09:53 +0200
Subject: [PATCH] Fixed CPU unusable on CUDA machines

---
 app.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/app.py b/app.py
index cda5468..979400e 100644
--- a/app.py
+++ b/app.py
@@ -8811,6 +8811,12 @@ def whisper_device_select(self, device):
         :return:
         '''

+        allowed_devices = ['cuda', 'CUDA', 'gpu', 'GPU', 'cpu', 'CPU']
+
+        # change the whisper device if it was passed as a parameter
+        if device is not None and device in allowed_devices:
+            self.whisper_device = device
+
         # if the whisper device is set to cuda
         if self.whisper_device in ['cuda', 'CUDA', 'gpu', 'GPU']:
             # use CUDA if available
@@ -10110,6 +10116,10 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,

         # what is the name of the audio file
         audio_file_name = os.path.basename(audio_file_path)
+
+        whisper_device_changed = False
+        if 'device' in other_whisper_options and self.whisper_device != other_whisper_options['device']:
+            whisper_device_changed = True

         # select the device that was passed (if any)
         if 'device' in other_whisper_options:
@@ -10120,7 +10130,8 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
         # load OpenAI Whisper model
         # and hold it loaded for future use (unless another model was passed via other_whisper_options)
         if self.whisper_model is None \
-                or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model']):
+                or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model'])\
+                or whisper_device_changed:

             # update the status of the item in the transcription log
             self.update_transcription_log(unique_id=queue_id, **{'status': 'loading model'})
@@ -10147,7 +10158,7 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,

             logger.info('Loading Whisper {} model.'.format(self.whisper_model_name))

             try:
-                self.whisper_model = whisper.load_model(self.whisper_model_name)
+                self.whisper_model = whisper.load_model(self.whisper_model_name, device=self.whisper_device)
             except Exception as e:
                 logger.error('Error loading Whisper {} model: {}'.format(self.whisper_model_name, e))
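
For reference, a minimal standalone sketch of the behavior this patch enables: the requested device is validated, falls back to CPU when CUDA is not available, and is then passed explicitly to whisper.load_model instead of letting Whisper pick a default. This is not part of app.py; the pick_device helper and the hard-coded 'base' model name are illustrative assumptions only.

    import torch
    import whisper

    def pick_device(requested=None):
        # hypothetical helper: map a user-supplied device string to one
        # that torch/Whisper can actually use on this machine
        requested = (requested or 'cuda').lower()
        if requested in ('cuda', 'gpu') and torch.cuda.is_available():
            return 'cuda'
        # anything else, or no CUDA hardware, falls back to the CPU
        return 'cpu'

    # loading on the resolved device keeps 'cpu' usable on CUDA machines
    model = whisper.load_model('base', device=pick_device('cpu'))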