diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index ae3737ebd..802a64a7d 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -285,7 +285,8 @@ def _create_dataloader( if "batch_sampler" in dataloader_args.keys(): if "sampler" not in dataloader_args.keys(): raise KeyError( - "When specifying a `batch_sampler`, you must also provide `sampler`." + "When specifying a `batch_sampler`, " + "you must also provide `sampler`." ) # If there were no kwargs provided, set it to empty dict if "batch_sampler_kwargs" not in dataloader_args.keys(): diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index ae8f728fb..c43455447 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -1,4 +1,29 @@ -"""`Sampler` and `BatchSampler` objects for `graphnet`.""" +"""`Sampler` and `BatchSampler` objects for `graphnet`. + +MIT License + +Copyright (c) 2023 DrHB + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +_____________________ +""" + from typing import ( Any, List, @@ -21,30 +46,8 @@ class RandomChunkSampler(Sampler[int]): """A `Sampler` that randomly selects chunks. - MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -157,30 +160,8 @@ def gather_len_matched_buckets( class LenMatchBatchSampler(BatchSampler, Logger): """A `BatchSampler` that batches similar length events. 
- MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -195,20 +176,21 @@ def __init__( ) -> None: """Construct `LenMatchBatchSampler`. - This `BatchSampler` groups data with similar lengths to be more efficient - in operations like masking for MultiHeadAttention. Since batch samplers - run on the main process and can result in a CPU bottleneck, `num_workers` - can be specified to use multiprocessing for creating the batches. The - `bucket_width` argument specifies how wide the bins are for grouping batches. - For example, with `bucket_width=16`, data with length [1, 16] are grouped into - a bucket, data with length [17, 32] into another, etc. 
+ This `BatchSampler` groups data with similar lengths to be more + efficient in operations like masking for MultiHeadAttention. Since + batch samplers run on the main process and can result in a CPU + bottleneck, `num_workers` can be specified to use multiprocessing for + creating the batches. The `bucket_width` argument specifies how wide + the bins are for grouping batches. For example, with `bucket_width=16`, + data with length [1, 16] are grouped into a bucket, data with length + [17, 32] into another, etc. Args: sampler: A `Sampler` object that selects/draws data in some way. batch_size: Batch size. num_workers: Number of workers to spawn to create batches. bucket_width: Size of length buckets for grouping data. - chunks_per_segment: Number of chunks to group together for processing. + chunks_per_segment: Number of chunks to group together. multiprocessing_context: Start method for multiprocessing. drop_last: (Optional) Drop the last incomplete batch. """ @@ -237,7 +219,7 @@ def __iter__(self) -> Iterator[List[int]]: n_chunks = len(self.sampler.chunks) n_segments = n_chunks // self._chunks_per_segment - # Split indices into nearly equal-sized segments amongst the workers + # Split indices into nearly equal-sized segments amongst workers segments = [ range( sum(self.sampler.chunks[: i * self._chunks_per_segment]),