Skip to content

Commit

Permalink
Fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
pweigel committed Sep 2, 2024
1 parent 9b6e568 commit d6e7bed
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 59 deletions.
3 changes: 2 additions & 1 deletion src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def _create_dataloader(
if "batch_sampler" in dataloader_args.keys():
if "sampler" not in dataloader_args.keys():
raise KeyError(
"When specifying a `batch_sampler`, you must also provide `sampler`."
"When specifying a `batch_sampler`,"
"you must also provide `sampler`."
)
# If there were no kwargs provided, set it to empty dict
if "batch_sampler_kwargs" not in dataloader_args.keys():
Expand Down
98 changes: 40 additions & 58 deletions src/graphnet/data/dataset/samplers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
"""`Sampler` and `BatchSampler` objects for `graphnet`."""
"""`Sampler` and `BatchSampler` objects for `graphnet`.
MIT License
Copyright (c) 2023 DrHB
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
_____________________
"""

from typing import (
Any,
List,
Expand All @@ -21,30 +46,8 @@
class RandomChunkSampler(Sampler[int]):
"""A `Sampler` that randomly selects chunks.
MIT License
Copyright (c) 2023 DrHB
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
_____________________
Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py
Original implementation:
https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py
"""

def __init__(
Expand Down Expand Up @@ -157,30 +160,8 @@ def gather_len_matched_buckets(
class LenMatchBatchSampler(BatchSampler, Logger):
"""A `BatchSampler` that batches similar length events.
MIT License
Copyright (c) 2023 DrHB
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
_____________________
Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py
Original implementation:
https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py
"""

def __init__(
Expand All @@ -195,20 +176,21 @@ def __init__(
) -> None:
"""Construct `LenMatchBatchSampler`.
This `BatchSampler` groups data with similar lengths to be more efficient
in operations like masking for MultiHeadAttention. Since batch samplers
run on the main process and can result in a CPU bottleneck, `num_workers`
can be specified to use multiprocessing for creating the batches. The
`bucket_width` argument specifies how wide the bins are for grouping batches.
For example, with `bucket_width=16`, data with length [1, 16] are grouped into
a bucket, data with length [17, 32] into another, etc.
This `BatchSampler` groups data with similar lengths to be more
efficient in operations like masking for MultiHeadAttention. Since
batch samplers run on the main process and can result in a CPU
bottleneck, `num_workers` can be specified to use multiprocessing for
creating the batches. The `bucket_width` argument specifies how wide
the bins are for grouping batches. For example, with `bucket_width=16`,
data with length [1, 16] are grouped into a bucket, data with length
[17, 32] into another, etc.
Args:
sampler: A `Sampler` object that selects/draws data in some way.
batch_size: Batch size.
num_workers: Number of workers to spawn to create batches.
bucket_width: Size of length buckets for grouping data.
chunks_per_segment: Number of chunks to group together for processing.
chunks_per_segment: Number of chunks to group together.
multiprocessing_context: Start method for multiprocessing.
drop_last: (Optional) Drop the last incomplete batch.
"""
Expand Down Expand Up @@ -237,7 +219,7 @@ def __iter__(self) -> Iterator[List[int]]:
n_chunks = len(self.sampler.chunks)
n_segments = n_chunks // self._chunks_per_segment

# Split indices into nearly equal-sized segments amongst the workers
# Split indices into nearly equal-sized segments amongst workers
segments = [
range(
sum(self.sampler.chunks[: i * self._chunks_per_segment]),
Expand Down

0 comments on commit d6e7bed

Please sign in to comment.