From bd44bd6bfbe19d9c16b399daceb73cbc545222ea Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 21:03:40 -0400 Subject: [PATCH 1/7] Relabel "batches" in ParquetDataset, add samplers * Renamed the batch variables in ParquetDataset to chunk variables * Implemented RandomChunkSampler and LenMatchBatchSampler w/ modifications --- src/graphnet/data/dataset/__init__.py | 1 + .../data/dataset/parquet/parquet_dataset.py | 28 ++- src/graphnet/data/dataset/samplers.py | 232 ++++++++++++++++++ 3 files changed, 250 insertions(+), 11 deletions(-) create mode 100644 src/graphnet/data/dataset/samplers.py diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index f6eafee94..eeb3123d9 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -5,6 +5,7 @@ if has_torch_package(): import torch.multiprocessing from .dataset import EnsembleDataset, Dataset, ColumnMissingException + from .samplers import RandomChunkSampler, LenMatchBatchSampler from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py index 3561c591a..2df6ed16e 100644 --- a/src/graphnet/data/dataset/parquet/parquet_dataset.py +++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py @@ -5,6 +5,7 @@ List, Optional, Union, + Any, ) import numpy as np @@ -92,7 +93,7 @@ def __init__( `"10000 random events ~ event_no % 5 > 0"` or `"20% random events ~ event_no % 5 > 0"`). graph_definition: Method that defines the graph representation. - cache_size: Number of batches to cache in memory. + cache_size: Number of files to cache in memory. Must be at least 1. Defaults to 1. labels: Dictionary of labels to be added to the dataset. """ @@ -123,8 +124,8 @@ def __init__( self._path: str = self._path # Member Variables self._cache_size = cache_size - self._batch_sizes = self._calculate_sizes() - self._batch_cumsum = np.cumsum(self._batch_sizes) + self._chunk_sizes = self._calculate_sizes() + self._chunk_cumsum = np.cumsum(self._chunk_sizes) self._file_cache = self._initialize_file_cache( truth_table=truth_table, node_truth_table=node_truth_table, @@ -179,9 +180,14 @@ def _get_event_index(self, sequential_index: int) -> int: ) return event_index + @property + def chunk_sizes(self) -> List[int]: + """Return a list of the chunk sizes.""" + return self._chunk_sizes + def __len__(self) -> int: """Return length of dataset, i.e. 
number of training examples.""" - return sum(self._batch_sizes) + return sum(self._chunk_sizes) def _get_all_indices(self) -> List[int]: """Return a list of all unique values in `self._index_column`.""" @@ -189,22 +195,22 @@ def _get_all_indices(self) -> List[int]: return np.arange(0, len(files), 1) def _calculate_sizes(self) -> List[int]: - """Calculate the number of events in each batch.""" + """Calculate the number of events in each chunk.""" sizes = [] - for batch_id in self._indices: + for chunk_id in self._indices: path = os.path.join( self._path, self._truth_table, - f"{self.truth_table}_{batch_id}.parquet", + f"{self.truth_table}_{chunk_id}.parquet", ) sizes.append(len(pol.read_parquet(path))) return sizes def _get_row_idx(self, sequential_index: int) -> int: """Return the row index corresponding to a `sequential_index`.""" - file_idx = bisect_right(self._batch_cumsum, sequential_index) + file_idx = bisect_right(self._chunk_cumsum, sequential_index) if file_idx > 0: - idx = int(sequential_index - self._batch_cumsum[file_idx - 1]) + idx = int(sequential_index - self._chunk_cumsum[file_idx - 1]) else: idx = sequential_index return idx @@ -241,9 +247,9 @@ def query_table( # type: ignore columns = [columns] if sequential_index is None: - file_idx = np.arange(0, len(self._batch_cumsum), 1) + file_idx = np.arange(0, len(self._chunk_cumsum), 1) else: - file_idx = [bisect_right(self._batch_cumsum, sequential_index)] + file_idx = [bisect_right(self._chunk_cumsum, sequential_index)] file_indices = [self._indices[idx] for idx in file_idx] diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py new file mode 100644 index 000000000..0d102898b --- /dev/null +++ b/src/graphnet/data/dataset/samplers.py @@ -0,0 +1,232 @@ +"""`Sampler` and `BatchSampler` objects for `graphnet`.""" +from typing import ( + Any, + List, + Optional, + Tuple, + Iterator, + Sequence, +) + +from collections import defaultdict +from multiprocessing import Pool, cpu_count, get_context + +import numpy as np +import torch +from torch.utils.data import Sampler, BatchSampler + + +class RandomChunkSampler(Sampler[int]): + """A `Sampler` that randomly selects chunks. + + MIT License + + Copyright (c) 2023 DrHB + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+ """ + + def __init__( + self, + data_source: Sequence[Any], + chunks: List[int], + num_samples: Optional[int] = None, + generator: Optional[torch.Generator] = None, + ) -> None: + """Construct `RandomChunkSampler`.""" + # chunks - a list of chunk sizes + self._data_source = data_source + self._num_samples = num_samples + self._chunks = chunks + + # Create a random number generator if one was not provided + if generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._generator = torch.Generator() + self._generator.manual_seed(seed) + else: + self._generator = generator + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + "num_samples should be a positive integer " + "value, but got num_samples={}".format(self.num_samples) + ) + + @property + def data_source(self) -> Sequence[Any]: + """Return the data source.""" + return self._data_source + + @property + def num_samples(self) -> int: + """Return the number of samples in the data source.""" + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __len__(self) -> int: + """Return the number of sampled.""" + return self.num_samples + + @property + def chunks(self) -> List[int]: + """Return the list of chunks.""" + return self._chunks + + def __iter__(self) -> Iterator[List[int]]: + """Return a list of indices from a randomly sampled chunk.""" + cumsum = np.cumsum(self.chunks) + chunk_list = torch.randperm( + len(self.chunks), generator=self.generator + ).tolist() + + # sample indexes chunk by chunk + yield_samples = 0 + for i in chunk_list: + chunk_len = self.chunks[i] + offset = cumsum[i - 1] if i > 0 else 0 + samples = ( + offset + torch.randperm(chunk_len, generator=self.generator) + ).tolist() + if len(samples) <= self.num_samples - yield_samples: + yield_samples += len(samples) + else: + samples = samples[: self.num_samples - yield_samples] + yield_samples = self.num_samples + yield from samples + + +def gather_buckets( + params: Tuple[List[int], Sequence[Any], int], +) -> Tuple[List[List[int]], List[List[int]]]: + """Gather buckets of events. + + The function that will be used to gather buckets of events by the + `LenMatchBatchSampler`. When using multiprocessing, each worker will call + this function. + + Args: + params: A tuple containg the list of indices to process, + the data_source (typically a `Dataset`), and the batch size. + + Returns: + batches: A list containing batches. + remaining_batches: Incomplete batches. + """ + indices, data_source, batch_size = params + buckets = defaultdict(list) + batches = [] + + for idx in indices: + s = data_source[idx] + L = max(1, s.num_nodes // 16) + buckets[L].append(idx) + if len(buckets[L]) == batch_size: + batches.append(list(buckets[L])) + buckets[L] = [] + + # Include any remaining items in partially filled buckets + remaining_batches = [b for b in buckets.values() if b] + return batches, remaining_batches + + +class LenMatchBatchSampler(BatchSampler): + """A `BatchSampler` that batches similar length events. 
+ + MIT License + + Copyright (c) 2023 DrHB + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + + def __init__( + self, + sampler: Sampler, + batch_size: int, + drop_last: Optional[bool] = False, + ) -> None: + """Construct `LenMatchBatchSampler`.""" + super().__init__( + sampler=sampler, batch_size=batch_size, drop_last=drop_last + ) + + def __iter__(self) -> Iterator[List[int]]: + """Return length-matched batches.""" + indices = list(self.sampler) + data_source = self.sampler.data_source + + n_workers = min(cpu_count(), 6) + chunk_size = len(indices) // n_workers + + # Split indices into nearly equal-sized chunks + chunks = [ + indices[i * chunk_size : (i + 1) * chunk_size] + for i in range(n_workers) + ] + if len(indices) % n_workers != 0: + chunks.append(indices[n_workers * chunk_size :]) + + yielded = 0 + with get_context("spawn").Pool(processes=n_workers) as pool: + results = pool.map( + gather_buckets, + [(chunk, data_source, self.batch_size) for chunk in chunks], + ) + + merged_batches = [] + remaining_indices = [] + for batches, remaining in results: + merged_batches.extend(batches) + remaining_indices.extend(remaining) + + for batch in merged_batches: + yield batch + yielded += 1 + + # Process any remaining indices + leftover = [idx for batch in remaining_indices for idx in batch] + batch = [] + for idx in leftover: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + yielded += 1 + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch + yielded += 1 From 0702c474e8202f0119f905bdaa694a11c00912ab Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:18:37 -0400 Subject: [PATCH 2/7] Implementation of Samplers and BatchSamplers into GraphNeTDataModule --- src/graphnet/data/datamodule.py | 40 ++++++++++++++++++- src/graphnet/data/dataset/samplers.py | 55 +++++++++++++++++---------- 2 files changed, 74 insertions(+), 21 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 33f31c5fe..7ab2bbe06 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -195,6 +195,13 @@ def setup(self, stage: str) -> None: if self._val_selection is not None: self._val_dataset = self._create_dataset(self._val_selection) + # if self._len_match_batch: # TODO: the same for val -PW + # batch_size = self._train_batch_sampler_kwargs["batch_size"] + # self._train_random_chunk_sampler = RandomChunkSampler(self._train_dataset, + # chunks=self._train_dataset.chunk_sizes) + # 
self._train_len_match_batch_sampler = LenMatchBatchSampler(self._train_random_chunk_sampler, + # batch_size=batch_size, + # drop_last=True) return @property @@ -273,6 +280,38 @@ def _create_dataloader( "Unknown dataset encountered during dataloader creation." ) + if "sampler" in dataloader_args.keys(): + # If there were no kwargs provided, set it to empty dict + if "sampler_kwargs" not in dataloader_args.keys(): + dataloader_args["sampler_kwargs"] = {} + dataloader_args["sampler"] = dataloader_args["sampler"]( + dataset, **dataloader_args["sampler_kwargs"] + ) + del dataloader_args["sampler_kwargs"] + + if "batch_sampler" in dataloader_args.keys(): + if "sampler" not in dataloader_args.keys(): + raise KeyError( + "When specifying a `batch_sampler`, you must also provide `sampler`." + ) + # If there were no kwargs provided, set it to empty dict + if "batch_sampler_kwargs" not in dataloader_args.keys(): + dataloader_args["batch_sampler_kwargs"] = {} + + batch_sampler = dataloader_args["batch_sampler"]( + dataloader_args["sampler"], + **dataloader_args["batch_sampler_kwargs"], + ) + dataloader_args["batch_sampler"] = batch_sampler + # Remove extra keys + for key in [ + "batch_sampler_kwargs", + "drop_last", + "sampler", + "shuffle", + ]: + dataloader_args.pop(key, None) + if dataloader_args is None: raise AttributeError("Dataloader arguments not provided.") @@ -479,7 +518,6 @@ def _infer_selections_on_single_dataset( .sample(frac=1, replace=False, random_state=self._rng) .values.tolist() ) # shuffled list - return self._split_selection(all_events) def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 0d102898b..7a264088a 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -14,6 +14,7 @@ import numpy as np import torch from torch.utils.data import Sampler, BatchSampler +from graphnet.data.dataset import Dataset class RandomChunkSampler(Sampler[int]): @@ -40,20 +41,21 @@ class RandomChunkSampler(Sampler[int]): LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ _____________________ + + Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( self, - data_source: Sequence[Any], - chunks: List[int], + data_source: Dataset, num_samples: Optional[int] = None, generator: Optional[torch.Generator] = None, ) -> None: """Construct `RandomChunkSampler`.""" - # chunks - a list of chunk sizes self._data_source = data_source self._num_samples = num_samples - self._chunks = chunks + self._chunks = data_source.chunk_sizes # Create a random number generator if one was not provided if generator is None: @@ -94,7 +96,7 @@ def __iter__(self) -> Iterator[List[int]]: """Return a list of indices from a randomly sampled chunk.""" cumsum = np.cumsum(self.chunks) chunk_list = torch.randperm( - len(self.chunks), generator=self.generator + len(self.chunks), generator=self._generator ).tolist() # sample indexes chunk by chunk @@ -103,7 +105,7 @@ def __iter__(self) -> Iterator[List[int]]: chunk_len = self.chunks[i] offset = cumsum[i - 1] if i > 0 else 0 samples = ( - offset + torch.randperm(chunk_len, generator=self.generator) + offset + torch.randperm(chunk_len, generator=self._generator) ).tolist() if len(samples) <= self.num_samples - yield_samples: yield_samples += len(samples) @@ -114,29 +116,33 @@ def __iter__(self) -> Iterator[List[int]]: def gather_buckets( - params: Tuple[List[int], Sequence[Any], int], + params: Tuple[List[int], Sequence[Any], int, int], ) -> Tuple[List[List[int]], List[List[int]]]: """Gather buckets of events. - The function that will be used to gather buckets of events by the + The function that will be used to gather batches of events for the `LenMatchBatchSampler`. When using multiprocessing, each worker will call - this function. + this function. Given indices, this function will group events based on + their length. If the length of event is N, then it will go into the + (N // bucket_width) bucket. This returns completed batches and a + list of incomplete batches that did not fill to batch_size at the end. Args: params: A tuple containg the list of indices to process, - the data_source (typically a `Dataset`), and the batch size. + the data_source (typically a `Dataset`), the batch size, and the + bucket width. Returns: batches: A list containing batches. remaining_batches: Incomplete batches. """ - indices, data_source, batch_size = params + indices, data_source, batch_size, bucket_width = params buckets = defaultdict(list) batches = [] for idx in indices: s = data_source[idx] - L = max(1, s.num_nodes // 16) + L = max(1, s.num_nodes // bucket_width) buckets[L].append(idx) if len(buckets[L]) == batch_size: batches.append(list(buckets[L])) @@ -171,40 +177,49 @@ class LenMatchBatchSampler(BatchSampler): LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ _____________________ + + Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( self, sampler: Sampler, - batch_size: int, + batch_size: int = 1, + num_workers: int = 1, + bucket_width: int = 16, drop_last: Optional[bool] = False, ) -> None: """Construct `LenMatchBatchSampler`.""" super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) + self._bucket_width = bucket_width + self._num_workers = num_workers def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" indices = list(self.sampler) data_source = self.sampler.data_source - n_workers = min(cpu_count(), 6) - chunk_size = len(indices) // n_workers + chunk_size = len(indices) // self._num_workers # Split indices into nearly equal-sized chunks chunks = [ indices[i * chunk_size : (i + 1) * chunk_size] - for i in range(n_workers) + for i in range(self._num_workers) ] - if len(indices) % n_workers != 0: - chunks.append(indices[n_workers * chunk_size :]) + if len(indices) % self._num_workers != 0: + chunks.append(indices[self._num_workers * chunk_size :]) yielded = 0 - with get_context("spawn").Pool(processes=n_workers) as pool: + with get_context("spawn").Pool(processes=self._num_workers) as pool: results = pool.map( gather_buckets, - [(chunk, data_source, self.batch_size) for chunk in chunks], + [ + (chunk, data_source, self.batch_size, self._bucket_width) + for chunk in chunks + ], ) merged_batches = [] From a570f7b93fd24fd7cbe1627275f365969d056875 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:20:48 -0400 Subject: [PATCH 3/7] Remove old comment block --- src/graphnet/data/datamodule.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 7ab2bbe06..ae3737ebd 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -195,13 +195,6 @@ def setup(self, stage: str) -> None: if self._val_selection is not None: self._val_dataset = self._create_dataset(self._val_selection) - # if self._len_match_batch: # TODO: the same for val -PW - # batch_size = self._train_batch_sampler_kwargs["batch_size"] - # self._train_random_chunk_sampler = RandomChunkSampler(self._train_dataset, - # chunks=self._train_dataset.chunk_sizes) - # self._train_len_match_batch_sampler = LenMatchBatchSampler(self._train_random_chunk_sampler, - # batch_size=batch_size, - # drop_last=True) return @property From 3132ac04854aae99fbe2ab47098b7fb73ec0895e Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:31:57 -0400 Subject: [PATCH 4/7] Add multiprocessing_context for LenMatchBatchSampler --- src/graphnet/data/dataset/samplers.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 7a264088a..a9938a968 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -188,14 +188,33 @@ def __init__( batch_size: int = 1, num_workers: int = 1, bucket_width: int = 16, + multiprocessing_context: str = "spawn", drop_last: Optional[bool] = False, ) -> None: - """Construct `LenMatchBatchSampler`.""" + """Construct `LenMatchBatchSampler`. + + This `BatchSampler` groups data with similar lengths to be more efficient + in operations like masking for MultiHeadAttention. 
Since batch samplers + run on the main process and can result in a CPU bottleneck, `num_workers` + can be specified to use multiprocessing for creating the batches. The + `bucket_width` argument specifies how wide the bins are for grouping batches. + For example, with `bucket_width=16`, data with length [1, 16] and grouped into + a bucket and data with length [17, 32] in another. + + Args: + sampler: A `Sampler` object that selects/draws data in some way. + batch_size: Batch size. + num_workers: Number of workers to spawn to create batches. + bucket_width: Size of length buckets for grouping data. + multiprocessing_context: Start method for multiprocessing. + drop_last: (Optional) Drop the last incomplete batch. + """ super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) - self._bucket_width = bucket_width self._num_workers = num_workers + self._bucket_width = bucket_width + self._multiprocessing_context = multiprocessing_context def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" @@ -213,7 +232,9 @@ def __iter__(self) -> Iterator[List[int]]: chunks.append(indices[self._num_workers * chunk_size :]) yielded = 0 - with get_context("spawn").Pool(processes=self._num_workers) as pool: + with get_context(self._multiprocessing_context).Pool( + processes=self._num_workers + ) as pool: results = pool.map( gather_buckets, [ From c261fc922042635cec3aa9ddd6e94a95370e58ac Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Fri, 16 Aug 2024 00:12:07 -0400 Subject: [PATCH 5/7] Improved LenMatchBatchSampler --- src/graphnet/data/dataset/samplers.py | 53 +++++++++++++-------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index a9938a968..45619b41b 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -198,8 +198,8 @@ def __init__( run on the main process and can result in a CPU bottleneck, `num_workers` can be specified to use multiprocessing for creating the batches. The `bucket_width` argument specifies how wide the bins are for grouping batches. - For example, with `bucket_width=16`, data with length [1, 16] and grouped into - a bucket and data with length [17, 32] in another. + For example, with `bucket_width=16`, data with length [1, 16] are grouped into + a bucket, data with length [17, 32] into another, etc. Args: sampler: A `Sampler` object that selects/draws data in some way. @@ -212,6 +212,10 @@ def __init__( super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) + assert ( + num_workers >= 1 + ), "Need at least one worker to use LenMatchBatchSampler!" 
+ self._num_workers = num_workers self._bucket_width = bucket_width self._multiprocessing_context = multiprocessing_context @@ -221,43 +225,38 @@ def __iter__(self) -> Iterator[List[int]]: indices = list(self.sampler) data_source = self.sampler.data_source - chunk_size = len(indices) // self._num_workers + segments_size = len(indices) // self._num_workers - # Split indices into nearly equal-sized chunks - chunks = [ - indices[i * chunk_size : (i + 1) * chunk_size] + # Split indices into nearly equal-sized segments amonst the workers + segments = [ + indices[i * segments_size : (i + 1) * segments_size] for i in range(self._num_workers) ] + + # Collect the leftovers into another segment if len(indices) % self._num_workers != 0: - chunks.append(indices[self._num_workers * chunk_size :]) + segments.append(indices[self._num_workers * segments_size :]) yielded = 0 - with get_context(self._multiprocessing_context).Pool( - processes=self._num_workers - ) as pool: - results = pool.map( + remaining_indices = [] + with get_context("spawn").Pool(processes=self._num_workers) as pool: + for result in pool.imap_unordered( gather_buckets, [ - (chunk, data_source, self.batch_size, self._bucket_width) - for chunk in chunks + (segment, data_source, self.batch_size, self._bucket_width) + for segment in segments ], - ) - - merged_batches = [] - remaining_indices = [] - for batches, remaining in results: - merged_batches.extend(batches) - remaining_indices.extend(remaining) - - for batch in merged_batches: - yield batch - yielded += 1 + ): + batches, leftovers = result + for batch in batches: + yield batch + yielded += 1 + remaining_indices.extend(leftovers) # Process any remaining indices - leftover = [idx for batch in remaining_indices for idx in batch] batch = [] - for idx in leftover: - batch.append(idx) + for incomplete_batch in remaining_indices: + batch.extend(incomplete_batch) if len(batch) == self.batch_size: yield batch yielded += 1 From 9b6e5688e0e06856d465567a1b99cad459e4f1e9 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 22 Aug 2024 08:17:36 -0400 Subject: [PATCH 6/7] Minor changes, more settings for samplers --- src/graphnet/data/dataset/__init__.py | 5 +- src/graphnet/data/dataset/samplers.py | 133 +++++++++++++++++--------- 2 files changed, 92 insertions(+), 46 deletions(-) diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index eeb3123d9..ed1c55ef5 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -5,7 +5,10 @@ if has_torch_package(): import torch.multiprocessing from .dataset import EnsembleDataset, Dataset, ColumnMissingException - from .samplers import RandomChunkSampler, LenMatchBatchSampler + from .samplers import ( + RandomChunkSampler, + LenMatchBatchSampler, + ) from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 45619b41b..ae8f728fb 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -15,6 +15,7 @@ import torch from torch.utils.data import Sampler, BatchSampler from graphnet.data.dataset import Dataset +from graphnet.utilities.logging import Logger class RandomChunkSampler(Sampler[int]): @@ -115,10 +116,10 @@ def __iter__(self) -> Iterator[List[int]]: yield from samples -def gather_buckets( - params: Tuple[List[int], Sequence[Any], int, int], +def gather_len_matched_buckets( + params: Tuple[range, 
Sequence[Any], int, int], ) -> Tuple[List[List[int]], List[List[int]]]: - """Gather buckets of events. + """Gather length-matched buckets of events. The function that will be used to gather batches of events for the `LenMatchBatchSampler`. When using multiprocessing, each worker will call @@ -153,7 +154,7 @@ def gather_buckets( return batches, remaining_batches -class LenMatchBatchSampler(BatchSampler): +class LenMatchBatchSampler(BatchSampler, Logger): """A `BatchSampler` that batches similar length events. MIT License @@ -188,6 +189,7 @@ def __init__( batch_size: int = 1, num_workers: int = 1, bucket_width: int = 16, + chunks_per_segment: int = 4, multiprocessing_context: str = "spawn", drop_last: Optional[bool] = False, ) -> None: @@ -206,62 +208,103 @@ def __init__( batch_size: Batch size. num_workers: Number of workers to spawn to create batches. bucket_width: Size of length buckets for grouping data. + chunks_per_segment: Number of chunks to group together for processing. multiprocessing_context: Start method for multiprocessing. drop_last: (Optional) Drop the last incomplete batch. """ + Logger.__init__(self) super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) - assert ( - num_workers >= 1 - ), "Need at least one worker to use LenMatchBatchSampler!" + assert num_workers >= 0, "`num_workers` must be >= 0!" self._num_workers = num_workers self._bucket_width = bucket_width + self._chunks_per_segment = chunks_per_segment self._multiprocessing_context = multiprocessing_context + self.info( + f"Setting up batch sampler with {self._num_workers} workers." + ) + def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" indices = list(self.sampler) data_source = self.sampler.data_source - segments_size = len(indices) // self._num_workers - - # Split indices into nearly equal-sized segments amonst the workers - segments = [ - indices[i * segments_size : (i + 1) * segments_size] - for i in range(self._num_workers) - ] - - # Collect the leftovers into another segment - if len(indices) % self._num_workers != 0: - segments.append(indices[self._num_workers * segments_size :]) - - yielded = 0 - remaining_indices = [] - with get_context("spawn").Pool(processes=self._num_workers) as pool: - for result in pool.imap_unordered( - gather_buckets, - [ - (segment, data_source, self.batch_size, self._bucket_width) - for segment in segments - ], - ): - batches, leftovers = result - for batch in batches: - yield batch - yielded += 1 - remaining_indices.extend(leftovers) - - # Process any remaining indices - batch = [] - for incomplete_batch in remaining_indices: - batch.extend(incomplete_batch) - if len(batch) == self.batch_size: + if self._num_workers > 0: + + n_chunks = len(self.sampler.chunks) + n_segments = n_chunks // self._chunks_per_segment + + # Split indices into nearly equal-sized segments amongst the workers + segments = [ + range( + sum(self.sampler.chunks[: i * self._chunks_per_segment]), + sum( + self.sampler.chunks[ + : (i + 1) * self._chunks_per_segment + ] + ), + ) + for i in range(n_segments) + ] + segments.extend( + [range(segments[-1][-1], len(indices) - 1)] + ) # Make a segment w/ the leftover indices + + remaining_indices = [] + with get_context(self._multiprocessing_context).Pool( + processes=self._num_workers + ) as pool: + results = pool.imap_unordered( + gather_len_matched_buckets, + [ + ( + segments[i], + data_source, + self.batch_size, + self._bucket_width, + ) + for i in range(n_segments) + ], + ) + for result in results: + batches, 
leftovers = result + for batch in batches: + yield batch + remaining_indices.extend(leftovers) + + # Process any remaining indices + batch = [] + for incomplete_batch in remaining_indices: + batch.extend(incomplete_batch) + if len(batch) >= self.batch_size: + yield batch[: self.batch_size] + batch = batch[self.batch_size :] + + if len(batch) > 0 and not self.drop_last: yield batch - yielded += 1 - batch = [] + else: # n_workers = 0, no multiprocessing + buckets = defaultdict(list) + + for idx in self.sampler: + s = self.sampler.data_source[idx] + L = max(1, s.num_nodes // self._bucket_width) + buckets[L].append(idx) + if len(buckets[L]) == self.batch_size: + batch = list(buckets[L]) + yield batch + buckets[L] = [] - if len(batch) > 0 and not self.drop_last: - yield batch - yielded += 1 + batch = [] + leftover = [idx for bucket in buckets for idx in bucket] + + for idx in leftover: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch From d6e7bed3af1734b7356f3ca5f8a3fed8d3eed4d2 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Mon, 2 Sep 2024 11:09:22 -0400 Subject: [PATCH 7/7] Fix docstrings --- src/graphnet/data/datamodule.py | 3 +- src/graphnet/data/dataset/samplers.py | 98 +++++++++++---------------- 2 files changed, 42 insertions(+), 59 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index ae3737ebd..802a64a7d 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -285,7 +285,8 @@ def _create_dataloader( if "batch_sampler" in dataloader_args.keys(): if "sampler" not in dataloader_args.keys(): raise KeyError( - "When specifying a `batch_sampler`, you must also provide `sampler`." + "When specifying a `batch_sampler`," + "you must also provide `sampler`." ) # If there were no kwargs provided, set it to empty dict if "batch_sampler_kwargs" not in dataloader_args.keys(): diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index ae8f728fb..c43455447 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -1,4 +1,29 @@ -"""`Sampler` and `BatchSampler` objects for `graphnet`.""" +"""`Sampler` and `BatchSampler` objects for `graphnet`. + +MIT License + +Copyright (c) 2023 DrHB + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+_____________________ +""" + from typing import ( Any, List, @@ -21,30 +46,8 @@ class RandomChunkSampler(Sampler[int]): """A `Sampler` that randomly selects chunks. - MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -157,30 +160,8 @@ def gather_len_matched_buckets( class LenMatchBatchSampler(BatchSampler, Logger): """A `BatchSampler` that batches similar length events. - MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -195,20 +176,21 @@ def __init__( ) -> None: """Construct `LenMatchBatchSampler`. - This `BatchSampler` groups data with similar lengths to be more efficient - in operations like masking for MultiHeadAttention. Since batch samplers - run on the main process and can result in a CPU bottleneck, `num_workers` - can be specified to use multiprocessing for creating the batches. The - `bucket_width` argument specifies how wide the bins are for grouping batches. 
- For example, with `bucket_width=16`, data with length [1, 16] are grouped into - a bucket, data with length [17, 32] into another, etc. + This `BatchSampler` groups data with similar lengths to be more + efficient in operations like masking for MultiHeadAttention. Since + batch samplers run on the main process and can result in a CPU + bottleneck, `num_workers` can be specified to use multiprocessing for + creating the batches. The `bucket_width` argument specifies how wide + the bins are for grouping batches. For example, with `bucket_width=16`, + data with length [1, 16] are grouped into a bucket, data with length + [17, 32] into another, etc. Args: sampler: A `Sampler` object that selects/draws data in some way. batch_size: Batch size. num_workers: Number of workers to spawn to create batches. bucket_width: Size of length buckets for grouping data. - chunks_per_segment: Number of chunks to group together for processing. + chunks_per_segment: Number of chunks to group together. multiprocessing_context: Start method for multiprocessing. drop_last: (Optional) Drop the last incomplete batch. """ @@ -237,7 +219,7 @@ def __iter__(self) -> Iterator[List[int]]: n_chunks = len(self.sampler.chunks) n_segments = n_chunks // self._chunks_per_segment - # Split indices into nearly equal-sized segments amongst the workers + # Split indices into nearly equal-sized segments amongst workers segments = [ range( sum(self.sampler.chunks[: i * self._chunks_per_segment]),