Adding multi gpu speech generation (#3149)
* skeleton code

* fix some errors for downloading the model

* fix some tqdm error

* fix some error

* fix some gpu errors with torch

* fix some gpu errors with torch

* testing simple way

* testing simple way

* testing simple way

* testing simple way

* actual code

* actual code

* final testing with serialization

* add multi_gpu speech generation

* fix some comments

* fix some style and quality
dame-cell authored Oct 10, 2024
1 parent fd9880d commit f1f2b4d
Showing 1 changed file with 234 additions and 0 deletions.
examples/inference/distributed/distributed_speech_generation.py
@@ -0,0 +1,234 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import pathlib
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import fire
import scipy.io.wavfile
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, VitsModel

from accelerate import PartialState
from accelerate.utils import tqdm


"""
Requirements: transformers accelerate fire scipy datasets
pip install transformers accelerate fire scipy datasets
Example usage:
accelerate launch distributed_speech_generation.py --output_path outputs --batch_size 8 --num_workers 2 --dataset_split train
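To pin the number of GPU processes explicitly: accelerate launch --num_processes 2 distributed_speech_generation.py --output_path outputs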
"""

"""
To play back a generated .wav file (for example, in a notebook):
import scipy.io.wavfile
import numpy as np
from IPython.display import Audio
sample_rate, audio_data = scipy.io.wavfile.read("path_to_your_wav_file.wav")
audio_data = audio_data.astype(np.float32) / 32768.0  # scale 16-bit PCM to [-1.0, 1.0]; skip for float WAVs
Audio(audio_data, rate=sample_rate)
"""


def load_pokemon_data(split: str, max_text_length: int):
"""Load Pokemon descriptions from the dataset"""
ds = load_dataset("svjack/pokemon-blip-captions-en-zh", split=split)

# Create dataset of dictionaries
dataset = []
for idx, text in enumerate(ds["en_text"]):
if len(text.strip()) > 0: # Skip empty descriptions
dataset.append(
{
"id": f"pokemon_{idx:06d}",
"text": text.strip()[:max_text_length], # Truncate long descriptions
"original_text": text.strip(), # Keep original for metadata
}
)
return dataset


class ExistsFilter:
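    """Filters out samples whose audio file already exists in output_dir, so interrupted runs can resume."""
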
    def __init__(self, output_dir: Union[pathlib.Path, str]):
        current_files = [f.split(".wav")[0] for f in os.listdir(output_dir) if f.endswith(".wav")]
        self.processed_files = set(current_files)
        print(f"Existing audio files found: {len(self.processed_files)}.")

    def __call__(self, x):
        return x["id"] not in self.processed_files


def preprocess_fn(sample, tokenizer, max_text_length: int):
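    """Tokenize a single sample, returning plain Python lists so examples are easy to slice and batch later."""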
    inputs = tokenizer(sample["text"], padding=False, truncation=True, max_length=max_text_length, return_tensors="pt")

    return {
        "input_ids": inputs["input_ids"][0].tolist(),
        "attention_mask": inputs["attention_mask"][0].tolist(),
        "id": sample["id"],
        "text": sample["text"],
        "original_text": sample["original_text"],
    }


def collate_fn(examples, tokenizer):
"""Collate batch of examples with proper padding"""
# Find max length in this batch
max_length = max(len(example["input_ids"]) for example in examples)

# Pad sequences to max_length
input_ids_list = []
attention_mask_list = []

for example in examples:
# Get current lengths
curr_len = len(example["input_ids"])
padding_length = max_length - curr_len

# Pad sequences
padded_input_ids = example["input_ids"] + [tokenizer.pad_token_id] * padding_length
padded_attention_mask = example["attention_mask"] + [0] * padding_length

input_ids_list.append(padded_input_ids)
attention_mask_list.append(padded_attention_mask)

# Convert to tensors
input_ids = torch.tensor(input_ids_list, dtype=torch.long)
attention_mask = torch.tensor(attention_mask_list, dtype=torch.long)

ids = [example["id"] for example in examples]
texts = [example["text"] for example in examples]
original_texts = [example["original_text"] for example in examples]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"ids": ids,
"texts": texts,
"original_texts": original_texts,
}


def create_dataloader(dataset, batch_size, distributed_state, tokenizer):
"""Create dataloader with preprocessing"""
processed_dataset = [preprocess_fn(item, tokenizer, max_text_length=200) for item in dataset]

# Split dataset for distributed processing
if distributed_state.num_processes > 1:
chunk_size = len(processed_dataset) // distributed_state.num_processes
start_idx = distributed_state.process_index * chunk_size
        end_idx = (
            start_idx + chunk_size
            if distributed_state.process_index < distributed_state.num_processes - 1
            else len(processed_dataset)
        )
        processed_dataset = processed_dataset[start_idx:end_idx]

    # Create batches
    batches = []
    for i in range(0, len(processed_dataset), batch_size):
        batch = processed_dataset[i : i + batch_size]
        batches.append(collate_fn(batch, tokenizer))
    return batches


def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, sampling_rate: int):
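    """Writer loop: pulls generated batches off the queue and writes a .wav plus JSON metadata per sample; a None item signals shutdown."""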
    while True:
        try:
            item = output_queue.get(timeout=5)
            if item is None:
                break
            waveforms, ids, texts, original_texts = item

            # Save each audio file and its metadata
            for waveform, file_id, text, original_text in zip(waveforms, ids, texts, original_texts):
                # Save audio
                wav_path = output_dir / f"{file_id}.wav"
                scipy.io.wavfile.write(wav_path, rate=sampling_rate, data=waveform.cpu().float().numpy())

                # Save metadata with both truncated and original text
                metadata = {
                    "text_used": text,
                    "original_text": original_text,
                    "model": "facebook/mms-tts-eng",
                    "sampling_rate": sampling_rate,
                }
                metadata_path = output_dir / f"{file_id}_metadata.json"
                with metadata_path.open("w") as f:
                    json.dump(metadata, f, indent=4)

        except queue.Empty:
            continue


def main(
    output_path: str = "speech_data",
    batch_size: int = 8,
    num_workers: int = 2,
    dataset_split: str = "train",
    model_name: str = "facebook/mms-tts-eng",
    max_text_length: int = 200,
):
    output_dir = pathlib.Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    distributed_state = PartialState()

    # Load model and tokenizer
    model = VitsModel.from_pretrained(
        model_name,
        device_map=distributed_state.device,
        torch_dtype=torch.float32,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load and filter data
    dataset = load_pokemon_data(dataset_split, max_text_length)
    exist_filter = ExistsFilter(output_dir)
    dataset = [item for item in dataset if exist_filter(item)]

    distributed_state.print(f"Processing {len(dataset)} Pokemon descriptions")

    # Create dataloader
    batches = create_dataloader(dataset, batch_size, distributed_state, tokenizer)

    # Setup output queue and save thread
    output_queue = queue.Queue()
    save_thread = ThreadPoolExecutor(max_workers=num_workers)
    save_future = save_thread.submit(save_results, output_queue, output_dir, model.config.sampling_rate)

    try:
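        # Each process generates speech only for its own shard of batches.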
        for batch in tqdm(batches, desc="Generating speech for Pokemon descriptions"):
            with torch.no_grad():
                outputs = model(
                    input_ids=batch["input_ids"].to(distributed_state.device, dtype=torch.long),
                    attention_mask=batch["attention_mask"].to(distributed_state.device, dtype=torch.long),
                ).waveform

            output_queue.put((outputs, batch["ids"], batch["texts"], batch["original_texts"]))
    finally:
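        # Always enqueue the shutdown sentinel so the writer loop exits, then wait for pending writes.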
        output_queue.put(None)
        save_thread.shutdown(wait=True)

    save_future.result()


if __name__ == "__main__":
    fire.Fire(main)
