Commit

Fix flake8 errors
turicas committed Nov 12, 2023
1 parent 714ea1e commit 7358507
Showing 1 changed file with 26 additions and 17 deletions.
faster_whisper/transcribe.py: 43 changes (26 additions & 17 deletions)
@@ -91,23 +91,24 @@ def __init__(
         """Initializes the Whisper model.
 
         Args:
-          model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, medium,
-            medium.en, large-v1, large-v2 (same as large), or large-v3), a path to a converted model directory, or a
-            CTranslate2-converted Whisper model ID from the Hugging Face Hub. When a size or a model ID is configured,
-            the converted model is downloaded from the Hugging Face Hub.
+          model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, small,
+            small.en, medium, medium.en, large-v1, large-v2 (same as large), or large-v3), a path
+            to a converted model directory, or a CTranslate2-converted Whisper model ID from the
+            Hugging Face Hub. When a size or a model ID is configured, the converted model is
+            downloaded from the Hugging Face Hub.
           device: Device to use for computation ("cpu", "cuda", "auto").
           device_index: Device ID to use.
-            The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2, 3]). In that case,
-            multiple transcriptions can run in parallel when transcribe() is called from multiple Python threads (see
-            also num_workers).
+            The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2,
+            3]). In that case, multiple transcriptions can run in parallel when transcribe() is
+            called from multiple Python threads (see also num_workers).
           compute_type: Type to use for computation.
             See https://opennmt.net/CTranslate2/quantization.html.
           cpu_threads: Number of threads to use when running on CPU (4 by default).
             A non zero value overrides the OMP_NUM_THREADS environment variable.
           num_workers: When transcribe() is called from multiple Python threads,
-            having multiple workers enables true parallelism when running the model (concurrent calls to
-            self.model.generate() will run in parallel). This can improve the global throughput at the cost of
-            increased memory usage.
+            having multiple workers enables true parallelism when running the model (concurrent
+            calls to self.model.generate() will run in parallel). This can improve the global
+            throughput at the cost of increased memory usage.
           download_root: Directory where the models should be saved. If not set, the models
             are saved in the standard Hugging Face cache directory.
           local_files_only: If True, avoid downloading the file and return the path to the
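
Aside (not part of the diff): the device_index and num_workers notes above work together. A minimal usage sketch, assuming faster-whisper's public WhisperModel class; the file names and parameter values are illustrative only:

# Hedged sketch, not from this commit: load one model across two GPUs and
# transcribe from two Python threads, per the docstring above.
from concurrent.futures import ThreadPoolExecutor

from faster_whisper import WhisperModel

model = WhisperModel(
    "large-v3",             # a size here triggers a Hugging Face Hub download
    device="cuda",
    device_index=[0, 1],    # load the model on GPUs 0 and 1
    compute_type="float16",
    num_workers=2,          # allow two concurrent generate() calls
)

def transcribe_file(path):
    # transcribe() returns a generator of segments plus an info object
    segments, info = model.transcribe(path)
    return " ".join(segment.text for segment in segments)

# Two threads, two workers: the transcriptions can genuinely overlap
# instead of queuing on a single worker.
with ThreadPoolExecutor(max_workers=2) as pool:
    texts = list(pool.map(transcribe_file, ["first.wav", "second.wav"]))
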
@@ -132,25 +133,33 @@ def __init__(
             intra_threads=cpu_threads,
             inter_threads=num_workers,
         )
-        is_large_v3 = "large-v3" in model_size_or_path  # TODO: check by inspecting `self.model` instead?
+        # TODO: check if it's large-v3 by inspecting `self.model` instead?
+        is_large_v3 = "large-v3" in model_size_or_path
+
         tokenizer_file = os.path.join(model_path, "tokenizer.json")
         if os.path.isfile(tokenizer_file):
             self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
         else:
             if is_large_v3:
-                # Tokenizer for large-v3 is different, so we need to load as in whisper-large-v3 and monkey patch it to
-                # have the `token_to_id` method.
+                # Tokenizer for large-v3 is different, so we need to load as in whisper-large-v3
+                # and monkey patch it to have the `token_to_id` method.
                 # TODO: load the new tokenizer without requiring `transformers`
                 from transformers import AutoProcessor
 
-                self.hf_tokenizer = AutoProcessor.from_pretrained("openai/whisper-large-v3").tokenizer
-                self.hf_tokenizer.token_to_id = lambda token: self.hf_tokenizer.convert_tokens_to_ids(token)
+                processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+                self.hf_tokenizer = processor.tokenizer
+                self.hf_tokenizer.token_to_id = (
+                    lambda token: self.hf_tokenizer.convert_tokens_to_ids(token)
+                )
             else:
                 self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
-                    "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
+                    "openai/whisper-tiny"
+                    + ("" if self.model.is_multilingual else ".en")
                 )
 
-        self.feature_extractor = FeatureExtractor(feature_size=80 if not is_large_v3 else 128)
+        self.feature_extractor = FeatureExtractor(
+            feature_size=80 if not is_large_v3 else 128
+        )
         self.num_samples_per_token = self.feature_extractor.hop_length * 2
         self.frames_per_second = (
             self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
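
Aside (not part of the diff): the monkey patch above exists because, presumably, code elsewhere in faster_whisper calls token_to_id() (the tokenizers.Tokenizer method), while the transformers tokenizer loaded for large-v3 only exposes convert_tokens_to_ids(). A minimal sketch of that adapter idea, using a made-up stub class and vocabulary entry:

# Hedged sketch, not from this commit: give an object that only has
# convert_tokens_to_ids() the token_to_id() interface callers expect.
class StubHFTokenizer:
    _vocab = {"<|startoftranscript|>": 50258}  # hypothetical entry

    def convert_tokens_to_ids(self, token):
        return self._vocab.get(token, -1)

tokenizer = StubHFTokenizer()

# Same shape as the patched lines in the diff: attach token_to_id() to the
# instance by delegating to convert_tokens_to_ids().
tokenizer.token_to_id = tokenizer.convert_tokens_to_ids

assert tokenizer.token_to_id("<|startoftranscript|>") == 50258

The diff routes through a lambda rather than assigning the bound method directly; either way, the attribute is added to the instance, not the class.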
