From 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:52:47 +0100 Subject: [PATCH] [Whisper] patch float type on mps (#35295) * fix float type on mps * make --- .../models/whisper/generation_whisper.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 2f58375f3de751..fdaeff14d78867 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -632,7 +632,9 @@ def generate( cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, ) - time_offset = seek.to(torch.float64) * time_precision / input_stride + time_offset = ( + seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride + ) seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) # 6.2 cut out next 30s segment from input features @@ -1805,6 +1807,7 @@ def _retrieve_segment( timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] timestamp_segment_indices.add_(1) token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else [] + device = seek_sequence.device # If whisper predicted a "end of segment" via a timestep token, let's go ever each # "end of segment" prediction and slice the decoding into segments accordingly @@ -1828,8 +1831,12 @@ def _retrieve_segment( end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin segments.append( { - "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision, - "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision, + "start": time_offset[prev_idx] + + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64) + * time_precision, + "end": time_offset[prev_idx] + + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64) + * time_precision, "tokens": sliced_tokens, "result": seek_outputs[idx], } @@ -1856,7 +1863,9 @@ def _retrieve_segment( last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision) if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64) + last_timestamp_pos = (timestamps[-1] - timestamp_begin).to( + torch.float32 if device.type == "mps" else torch.float64 + ) segments = [ { "start": time_offset[prev_idx],