Cache embedding transformations separately from other preprocessing #9

Merged · 4 commits · Aug 28, 2024

Changes from all commits

1 change: 0 additions & 1 deletion birdset/datamodule/beans_datamodule.py

```diff
@@ -106,5 +106,4 @@ def _preprocess_data(self, dataset):
             load_from_cache_file=True,
             num_proc=self.dataset_config.n_workers,
         )
-
         return dataset
```

68 changes: 24 additions & 44 deletions birdset/datamodule/embedding_datamodule.py

```diff
@@ -1,7 +1,8 @@
 from birdset.datamodule.components.transforms import EmbeddingTransforms
 from birdset.datamodule.base_datamodule import BaseDataModuleHF
 from birdset.configs import NetworkConfig, DatasetConfig, LoadersConfig
-from datasets import DatasetDict, Dataset, concatenate_datasets
+from datasets import DatasetDict, Dataset, concatenate_datasets, load_from_disk
 from dataclasses import asdict
+from collections import defaultdict
 from tabulate import tabulate
 from birdset.utils import pylogger
```
```diff
@@ -69,29 +70,29 @@ def __init__(
         self.embedding_model.eval()  # Set the model to evaluation mode
         self.sampling_rate = embedding_model.sampling_rate
         self.max_length = embedding_model.length
-        self.disk_save_path = os.path.join(
+        self.embeddings_save_path = os.path.join(
             self.dataset_config.data_dir,
-            f"{self.dataset_config.dataset_name}_processed_{self.dataset_config.seed}_{self.embedding_model_name}_{self.k_samples}_{self.average}_{self.val_batches}_{self.low_train}_{self.sampling_rate}_{self.max_length}",
+            f"{self.dataset_config.dataset_name}_processed_embedding_model_{self.embedding_model_name}_{self.average}_{self.sampling_rate}_{self.max_length}",
         )
         log.info(f"Using embedding model:{embedding_model.model_name} (Sampling Rate:{self.sampling_rate}, Window Size:{self.max_length})")

     def prepare_data(self):
         """
         Same as prepare_data in BaseDataModuleHF but checks if path exists and skips rest otherwise
         """
-        if not self._prepare_done and os.path.exists(self.disk_save_path):
-            #! We need to set self.len_trainset otherwise base_module doesn't work so we load train split to get the length
-            self.len_trainset = len(self._get_dataset("train"))
-            self._prepare_done = True
-
         log.info("Check if preparing has already been done.")
         if self._prepare_done:
             log.info("Skip preparing.")
             return
+        # Check if the embeddings for the dataset have already been computed
+        if os.path.exists(self.embeddings_save_path):
+            log.info(f"Embeddings found in {self.embeddings_save_path}, loading from disk")
+            dataset = load_from_disk(self.embeddings_save_path)
+        else:
+            log.info("Prepare Data")
+            dataset = self._load_data()
+            dataset = self._compute_embeddings(dataset)

-        log.info("Prepare Data")
-
-        dataset = self._load_data()
         dataset = self._preprocess_data(dataset)
         dataset = self._create_splits(dataset)
```

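Taken together, the new `prepare_data` is a plain load-or-compute cache: every parameter that changes the embeddings (embedding model name, averaging flag, sampling rate, window length) is baked into `embeddings_save_path`, while k-sampling and the rest of the cheap preprocessing stay outside the cached artifact. A minimal sketch of the pattern, assuming Hugging Face `datasets`; `compute_fn` is a hypothetical stand-in for `_load_data()` followed by `_compute_embeddings()`:

```python
import os
from datasets import DatasetDict, load_from_disk

def load_or_compute(save_path: str, compute_fn) -> DatasetDict:
    """Reuse cached embeddings when the cache directory exists."""
    if os.path.exists(save_path):
        # Cheap path: memory-map the Arrow files written by an earlier run
        return load_from_disk(save_path)
    # Expensive path: run the embedding model once; in the PR,
    # _compute_embeddings() itself calls save_to_disk(save_path)
    return compute_fn()
```

Because the cache key lives in the directory name, changing any embedding-relevant parameter simply creates a new directory instead of silently reusing stale embeddings.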
```diff
@@ -109,7 +110,6 @@ def _preprocess_data(self, dataset):

         # Check if actually a dict
         dataset = self._ksamples(dataset)
-        dataset = self._compute_embeddings(dataset)

         if self.dataset_config.task == 'multilabel':
             log.info(">> One-hot-encode classes")
```
```diff
@@ -197,8 +197,7 @@ def _ksamples(self, dataset):
         if self.low_train:
             del dataset['train']
             dataset['train'] = dataset['train_low']
-
-        del dataset['train_low']
+            del dataset['train_low']
```
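The `_ksamples` fix scopes `del dataset['train_low']` to the `low_train` branch, where that split is guaranteed to exist, rather than deleting it unconditionally. As a side note, the delete-assign-delete sequence can be collapsed into a single `pop`; this is only a sketch of an equivalent, not code from the PR:

```python
if self.low_train:
    # DatasetDict subclasses dict, so pop() removes 'train_low' and
    # hands it over as the replacement 'train' split in one step
    dataset["train"] = dataset.pop("train_low")
```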
```diff
@@ -209,7 +208,7 @@ def _ksamples(self, dataset):

     def _compute_embeddings(self, dataset):
         """
-        Compute Embeddings for the entire dataset and update the dataset in place.
+        Compute Embeddings for the entire dataset and save them to disk as a new DatasetDict. If the embeddings have already been computed, the dataset will be loaded from disk.
         """
         # Define the function that will be applied to each sample
         def compute_and_update_embedding(sample):
```
```diff
@@ -221,15 +220,21 @@ def compute_and_update_embedding(sample):
             sample['embedding']['array'] = embedding.squeeze(0).cpu().numpy()
             sample.pop('audio')  # Remove audio to save space
             return sample

+        def get_new_fingerprint(split):
+            old_fingerprint = dataset[split]._fingerprint
+            return f"{old_fingerprint}_embedding_model_{self.embedding_model_name}_{self.average}_{self.sampling_rate}_{self.max_length}"
+
         # Apply the transformation to each split in the dataset
         for split in dataset.keys():
             log.info(f">> Extracting Embeddings for {split} Split")
             # Apply the embedding function to each sample in the split
-            dataset[split] = dataset[split].map(compute_and_update_embedding, desc="Extracting Embeddings")
+            dataset[split] = dataset[split].map(compute_and_update_embedding, desc="Extracting Embeddings", load_from_cache_file=True, new_fingerprint=get_new_fingerprint(split), num_proc=self.dataset_config.n_workers)

-        return dataset
+        log.info(f"Saving embeddings to disk: {self.embeddings_save_path}")
+        dataset.save_to_disk(self.embeddings_save_path)
+
+        return dataset

     def _get_embedding(self, audio):
         # Get waveform and sampling rate
```
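Why the explicit `new_fingerprint`: `Dataset.map` normally derives its cache key by hashing the mapped function, and (as the removed `_save_dataset_to_disk` comments note) randomness somewhere in the model and resampling made that hash unstable, so the cache never hit. Supplying a deterministic fingerprint built only from parameters that actually affect the output lets `load_from_cache_file=True` locate the earlier Arrow cache file. A minimal sketch of the mechanism, with made-up data and parameter values:

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": [1.0, 2.0, 3.0]})

# Deterministic cache key: previous fingerprint plus the knobs that change
# the output ("demo-model", 32000, 5 are illustrative stand-ins)
new_fp = f"{ds._fingerprint}_demo-model_32000_5"[:64]  # keep it filename-sized

# A later run with identical parameters rebuilds the same fingerprint, so
# map() reuses the cached cache-<fingerprint>.arrow file instead of recomputing
doubled = ds.map(
    lambda sample: {"y": sample["x"] * 2},
    new_fingerprint=new_fp,
    load_from_cache_file=True,
)
```

The same parameters also appear in `embeddings_save_path`, so the directory-level cache and the per-split map cache invalidate together.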
```diff
@@ -286,29 +291,4 @@ def _frame_and_average(self, audio):
         # Average the embeddings
         averaged_embedding = embeddings.mean(dim=0)

-        return averaged_embedding
-
-    def _save_dataset_to_disk(self, dataset: Dataset | DatasetDict):
-        """
-        Saves the dataset to disk.
-
-        This method sets the format of the dataset to numpy, prepares the path where the dataset will be saved, and saves
-        the dataset to disk. If the dataset already exists on disk, it does not save the dataset again.
-
-        Args:
-            dataset (datasets.DatasetDict): The dataset to be saved. The dataset should be a Hugging Face `datasets.DatasetDict` object.
-
-        Returns:
-            None
-        """
-
-        #! Due to randomness somewhere in model, resampling, ... the fingerprint isn't the same
-        # For now we add k_samples, model_name (potentially k_samples could also be picked from extracted embeddings already, but this is easier and it doesn't take that long)
-        #! Manual deleting may be needed if you want to recompute the embeddings
-
-        dataset.set_format("np")  # Removes slight changes as dtype should be same
-        if os.path.exists(self.disk_save_path):
-            log.info(f"Train fingerprint found in {self.disk_save_path}, saving to disk is skipped")
-        else:
-            log.info(f"Saving to disk: {self.disk_save_path}")
-            dataset.save_to_disk(self.disk_save_path)
+        return averaged_embedding
```
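For orientation, `_frame_and_average` embeds fixed-length windows of a long clip and mean-pools them, which is what the surviving `embeddings.mean(dim=0)` line shows. A rough sketch of the idea, assuming non-overlapping windows, a dropped remainder, and a hypothetical `embed_fn` standing in for the real model:

```python
import torch

def frame_and_average(waveform: torch.Tensor, frame_len: int, embed_fn) -> torch.Tensor:
    # Cut the (1, T) waveform into (n, frame_len) non-overlapping windows;
    # samples past the last full window are discarded
    frames = waveform.unfold(-1, frame_len, frame_len).squeeze(0)
    # Embed each window independently, giving an (n, dim) matrix
    embeddings = torch.stack([embed_fn(frame) for frame in frames])
    # Mean-pool across windows into a single clip-level embedding
    return embeddings.mean(dim=0)
```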