Merge pull request #742 from pweigel/parquet_dataset_improvements
Improvements to parquet dataloading, sampling, batch sampling
pweigel authored Sep 11, 2024
2 parents dc7fa4f + d6e7bed commit e8140c5
Showing 4 changed files with 345 additions and 12 deletions.
34 changes: 33 additions & 1 deletion src/graphnet/data/datamodule.py
@@ -273,6 +273,39 @@ def _create_dataloader(
                 "Unknown dataset encountered during dataloader creation."
             )
 
+        if "sampler" in dataloader_args.keys():
+            # If there were no kwargs provided, set it to empty dict
+            if "sampler_kwargs" not in dataloader_args.keys():
+                dataloader_args["sampler_kwargs"] = {}
+            dataloader_args["sampler"] = dataloader_args["sampler"](
+                dataset, **dataloader_args["sampler_kwargs"]
+            )
+            del dataloader_args["sampler_kwargs"]
+
if "batch_sampler" in dataloader_args.keys():
if "sampler" not in dataloader_args.keys():
raise KeyError(
"When specifying a `batch_sampler`,"
"you must also provide `sampler`."
)
+            # If there were no kwargs provided, set it to empty dict
+            if "batch_sampler_kwargs" not in dataloader_args.keys():
+                dataloader_args["batch_sampler_kwargs"] = {}
+
+            batch_sampler = dataloader_args["batch_sampler"](
+                dataloader_args["sampler"],
+                **dataloader_args["batch_sampler_kwargs"],
+            )
+            dataloader_args["batch_sampler"] = batch_sampler
+            # Remove extra keys
+            for key in [
+                "batch_sampler_kwargs",
+                "drop_last",
+                "sampler",
+                "shuffle",
+            ]:
+                dataloader_args.pop(key, None)
+
         if dataloader_args is None:
             raise AttributeError("Dataloader arguments not provided.")
 
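Usage note: `sampler` and `batch_sampler` are expected to be classes rather than instances; the sampler is instantiated with the dataset and the batch sampler with that sampler, each receiving its matching `*_kwargs`. A hypothetical `dataloader_args` wired this way (the kwarg names mirror the code above; the `batch_size` and `drop_last` values are illustrative):

    dataloader_args = {
        "num_workers": 4,
        "sampler": RandomChunkSampler,          # called as sampler(dataset, **sampler_kwargs)
        "sampler_kwargs": {},
        "batch_sampler": LenMatchBatchSampler,  # called as batch_sampler(sampler, **batch_sampler_kwargs)
        "batch_sampler_kwargs": {"batch_size": 16, "drop_last": True},
    }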
@@ -479,7 +512,6 @@ def _infer_selections_on_single_dataset(
             .sample(frac=1, replace=False, random_state=self._rng)
             .values.tolist()
         )  # shuffled list
-
         return self._split_selection(all_events)
 
     def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset:
4 changes: 4 additions & 0 deletions src/graphnet/data/dataset/__init__.py
@@ -5,6 +5,10 @@
 if has_torch_package():
     import torch.multiprocessing
     from .dataset import EnsembleDataset, Dataset, ColumnMissingException
+    from .samplers import (
+        RandomChunkSampler,
+        LenMatchBatchSampler,
+    )
     from .parquet.parquet_dataset import ParquetDataset
     from .sqlite.sqlite_dataset import SQLiteDataset
 
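As a usage sketch (not part of this diff), the samplers exported here can be wired to a `ParquetDataset` by hand, mirroring the instantiation order in `datamodule.py` above. The path, table names, and feature lists below are hypothetical, and `LenMatchBatchSampler` is assumed to accept the standard `BatchSampler` arguments:

    from torch.utils.data import DataLoader

    from graphnet.data.dataloader import collate_fn
    from graphnet.data.dataset import (
        ParquetDataset,
        RandomChunkSampler,
        LenMatchBatchSampler,
    )
    from graphnet.models.detector.icecube import IceCube86
    from graphnet.models.graphs import KNNGraph

    dataset = ParquetDataset(
        path="/data/parquet",          # hypothetical merged-parquet directory
        pulsemaps="SRTInIcePulses",    # hypothetical pulsemap table
        truth_table="truth",
        features=["dom_x", "dom_y", "dom_z", "dom_time"],
        truth=["energy", "zenith"],
        graph_definition=KNNGraph(detector=IceCube86()),
    )

    # One chunk corresponds to one parquet file: files are visited in random
    # order, and events of similar length are batched together.
    sampler = RandomChunkSampler(dataset)
    batch_sampler = LenMatchBatchSampler(sampler, batch_size=16, drop_last=True)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)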
27 changes: 16 additions & 11 deletions src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -93,7 +93,7 @@ def __init__(
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
             graph_definition: Method that defines the graph representation.
-            cache_size: Number of batches to cache in memory.
+            cache_size: Number of files to cache in memory.
                 Must be at least 1. Defaults to 1.
             labels: Dictionary of labels to be added to the dataset.
         """
@@ -124,8 +124,8 @@ def __init__(
         self._path: str = self._path
         # Member Variables
         self._cache_size = cache_size
-        self._batch_sizes = self._calculate_sizes()
-        self._batch_cumsum = np.cumsum(self._batch_sizes)
+        self._chunk_sizes = self._calculate_sizes()
+        self._chunk_cumsum = np.cumsum(self._chunk_sizes)
         self._file_cache = self._initialize_file_cache(
             truth_table=truth_table,
             node_truth_table=node_truth_table,
@@ -180,32 +180,37 @@ def _get_event_index(self, sequential_index: int) -> int:
         )
         return event_index
 
+    @property
+    def chunk_sizes(self) -> List[int]:
+        """Return a list of the chunk sizes."""
+        return self._chunk_sizes
+
     def __len__(self) -> int:
         """Return length of dataset, i.e. number of training examples."""
-        return sum(self._batch_sizes)
+        return sum(self._chunk_sizes)
 
     def _get_all_indices(self) -> List[int]:
         """Return a list of all unique values in `self._index_column`."""
         files = glob(os.path.join(self._path, self._truth_table, "*.parquet"))
         return np.arange(0, len(files), 1)
 
     def _calculate_sizes(self) -> List[int]:
-        """Calculate the number of events in each batch."""
+        """Calculate the number of events in each chunk."""
         sizes = []
-        for batch_id in self._indices:
+        for chunk_id in self._indices:
             path = os.path.join(
                 self._path,
                 self._truth_table,
-                f"{self.truth_table}_{batch_id}.parquet",
+                f"{self.truth_table}_{chunk_id}.parquet",
             )
             sizes.append(len(pol.read_parquet(path)))
         return sizes
 
     def _get_row_idx(self, sequential_index: int) -> int:
         """Return the row index corresponding to a `sequential_index`."""
-        file_idx = bisect_right(self._batch_cumsum, sequential_index)
+        file_idx = bisect_right(self._chunk_cumsum, sequential_index)
         if file_idx > 0:
-            idx = int(sequential_index - self._batch_cumsum[file_idx - 1])
+            idx = int(sequential_index - self._chunk_cumsum[file_idx - 1])
         else:
             idx = sequential_index
         return idx
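For concreteness, the chunk arithmetic used here and in `query_table` below works as in this small worked example (values are illustrative):

    from bisect import bisect_right

    import numpy as np

    chunk_sizes = [100, 50, 200]           # events per parquet file
    chunk_cumsum = np.cumsum(chunk_sizes)  # array([100, 150, 350])

    # Sequential event 120 falls in the second file (events 100-149),
    # at row 20 within that file.
    file_idx = bisect_right(chunk_cumsum, 120)       # -> 1
    row_idx = int(120 - chunk_cumsum[file_idx - 1])  # -> 20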
@@ -242,9 +247,9 @@ def query_table(  # type: ignore
             columns = [columns]
 
         if sequential_index is None:
-            file_idx = np.arange(0, len(self._batch_cumsum), 1)
+            file_idx = np.arange(0, len(self._chunk_cumsum), 1)
         else:
-            file_idx = [bisect_right(self._batch_cumsum, sequential_index)]
+            file_idx = [bisect_right(self._chunk_cumsum, sequential_index)]
 
         file_indices = [self._indices[idx] for idx in file_idx]
292 changes: 292 additions & 0 deletions src/graphnet/data/dataset/samplers.py
(New file defining `RandomChunkSampler` and `LenMatchBatchSampler`; its diff did not load on this page. Line counts are inferred from the commit totals above.)
