Skip to content

Commit

Permalink
refactor(examples) Update whisper finetuning example (#4158)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 31, 2025
1 parent 6e736fd commit c007d67
Show file tree
Hide file tree
Showing 18 changed files with 1,012 additions and 814 deletions.
1 change: 1 addition & 0 deletions examples/whisper-federated-finetuning/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
processed_partitions/
262 changes: 138 additions & 124 deletions examples/whisper-federated-finetuning/README.md

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import argparse
import random

import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from utils import (
from torch.utils.data import DataLoader
from transformers import WhisperProcessor
from whisper_example.dataset import get_encoding_fn, prepare_silences_dataset
from whisper_example.model import (
construct_balanced_sampler,
eval_model,
get_encoding_fn,
get_model,
prepare_silences_dataset,
remove_cols,
train_one_epoch,
)

from datasets import concatenate_datasets, load_dataset

random.seed(1989)
torch.set_float32_matmul_precision(
"high"
) # If “high” or “medium” are set then the TensorFloat32 is used
NUM_CLASSES = 12
REMOVE_COLS = ["file", "audio", "label", "is_unknown", "speaker_id", "utterance_id"]
parser = argparse.ArgumentParser(description="Whisper centralised")

parser.add_argument("--checkpoint", type=str, help="path to classifier`s checkpoint")
Expand Down Expand Up @@ -56,10 +55,10 @@ def main():
torch.set_num_threads(
1
) # not clear to me why we need this in order to be able to use `num_proc > 1 for .map`
train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS)
val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS)
test_encoded = sc_test.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS
)

# create and pre-process the dataset of silences
Expand All @@ -68,26 +67,19 @@ def main():
# ! needed each time you run the code. Alternatively, this silence generation could be
# ! implemented as part of a `collate_fn` in the standard PyTorch dataloader...
encoded_silences = silences_dataset.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS
)
full_train_dataset = concatenate_datasets([train_encoded, encoded_silences])

torch.set_num_threads(og_threads)

lbls = set(full_train_dataset["targets"])
print(f"{lbls = }")
hist = np.histogram(full_train_dataset["targets"], bins=12)
print(f"{[int(count) for count in hist[0]]}")

# make balanced batches with a WeightedRandomSampler
w_per_class = (
len(full_train_dataset) / hist[0]
) # doesn't have to add up to 1 (relative is what matters)
print(f"{w_per_class = }")
w_ss = [w_per_class[t] for t in full_train_dataset["targets"]]
sampler = WeightedRandomSampler(w_ss, len(w_ss))

# prepare dataloaders
# Construct a balanced sampler so batches roughly contain the same number
# of samples from each class
sampler = construct_balanced_sampler(full_train_dataset)

# Prepare dataloaders
train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
train_loader = DataLoader(
train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
Expand All @@ -97,7 +89,7 @@ def main():
test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

# model to cuda, set criterion, classification layer to train and optimiser
# Model to cuda, set criterion, classification layer to train and optimiser
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder, classifier = get_model(device, num_classes=12)
criterion = torch.nn.CrossEntropyLoss()
Expand All @@ -113,7 +105,7 @@ def main():
classifier_head_params = sum(p.numel() for p in classifier.parameters())
print(f"{classifier_head_params = }")

# eval initial model
# Eval initial model
loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
print(f"Initial (loss, acc): {loss = }, {accuracy = }")
best = [-float("inf"), None]
Expand Down
185 changes: 0 additions & 185 deletions examples/whisper-federated-finetuning/client.py

This file was deleted.

60 changes: 60 additions & 0 deletions examples/whisper-federated-finetuning/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
from multiprocessing import Pool
from time import time

import tomli
from whisper_example.dataset import load_data

from datasets import load_dataset

parser = argparse.ArgumentParser(description="Whisper preprocessing")

parser.add_argument(
"--partition-id", type=int, help="The partition to create and save."
)

args = parser.parse_args()


# Open and read the pyproject.toml
with open("pyproject.toml", "rb") as file:
flwr_config = tomli.load(file)["tool"]["flwr"]

# Display
print(flwr_config)
remove_cols = flwr_config["app"]["config"]["remove-cols"]
num_supernodes = flwr_config["federations"]["local-sim"]["options"]["num-supernodes"]

# If specified one partition, only that one will be processed and saved to the current directory
if args.partition_id:
print(f"Pre-processing partition {args.partition_id} only.")
else:
print(f"Pre-processing dataset into {num_supernodes} partitions.")


def process_one_partition(partition_id: int, save: bool = False):
pp = load_data(partition_id, remove_cols)
if save:
file_name = f"partition_{partition_id}"
pp.save_to_disk(file_name)
print(f"Saved partition to disk: {file_name}")


if __name__ == "__main__":

# Download train set
_ = load_dataset("speech_commands", "v0.02", split="train", token=False)

# Parallelize the processing of each partition in the dataset
t_start = time()
num_proc = None # set it if you want to limit the number of processes

if args.partition_id:
process_one_partition(args.partition_id, True)

else:
with Pool(num_proc) as pool:
pool.map(process_one_partition, range(num_supernodes))
print(
f"Pre-processing {num_supernodes} partitions took: {time() - t_start:.2f} s"
)
Loading

0 comments on commit c007d67

Please sign in to comment.