Commit

make subclassing friendlier

Linux-cpp-lisp committed May 30, 2022
1 parent 722755f · commit f766579

Showing 2 changed files with 147 additions and 31 deletions.
116 changes: 116 additions & 0 deletions examples/custom_dataset.py
@@ -0,0 +1,116 @@
from typing import Dict, List, Callable, Union, Optional
import numpy as np
import logging

import torch

from nequip.data import AtomicData
from nequip.utils.savenload import atomic_write
from nequip.data.transforms import TypeMapper
from nequip.data import AtomicDataset


class ExampleCustomDataset(AtomicDataset):
"""
See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets.
If you don't need downloading or pre-processing, just don't define any of the relevant methods/properties.
"""

def __init__(
self,
root: str,
custom_option1,
custom_option2="default",
type_mapper: Optional[TypeMapper] = None,
):
        # Store the options as attributes *before* calling `super().__init__()`:
        # `AtomicDataset._get_parameters()` reads them off `self` to compute the hash
        # that names the processed cache directory (`self.processed_dir`).
        self.custom_option1 = custom_option1
        self.custom_option2 = custom_option2
        self.mydata = None

        # Initialize the AtomicDataset, which runs .download() (if present) and .process().
        # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
        # Download and preprocessing only run if the cached dataset files aren't found.
        super().__init__(root=root, type_mapper=type_mapper)

        # If the processed paths didn't exist, `self.process()` has been called at this point
        # (if it is defined) and has already populated `self.mydata`; otherwise, load the
        # data from the cached pre-processed directory:
        if self.mydata is None:
            self.mydata = torch.load(self.processed_paths[0])
        # If you didn't define `process()`, this is where you would unconditionally load your data.

    def len(self) -> int:
        """Return the number of frames in the dataset."""
        return 42  # placeholder: return the real number of frames in your data

@property
def raw_file_names(self) -> List[str]:
"""Return a list of filenames for the raw data.
Need to be simple filenames to be looked for in `self.raw_dir`
"""
return ["data.dat"]

@property
def raw_dir(self) -> str:
return "/path/to/dataset-folder/"

@property
def processed_file_names(self) -> List[str]:
"""Like `self.raw_file_names`, but for the files generated by `self.process()`.
Should not be paths, just file names. These will be stored in `self.processed_dir`,
which is set by NequIP in `AtomicDataset` based on `self.root` and a hash of the
dataset options provided to `__init__`.
"""
return ["processed-data.pth"]

# def download(self):
# """Optional method to download raw data before preprocessing if the `raw_paths` do not exist."""
# pass

def process(self):
# load things from the raw data:
# whatever is appropriate for your format
data = np.load(self.raw_dir + "/" + self.raw_file_names[0])

# if any pre-processing is necessary, do it and cache the results to
# `self.processed_paths` as you defined above:
with atomic_write(self.processed_paths[0], binary=True) as f:
            # e.g., anything that writes to a file object `f` will work
            torch.save(data, f)
        # ^ use atomic writes to avoid race conditions between
        # different trainings that use the same dataset.
        # Since those separate trainings should all produce the same results,
        # it doesn't matter if they overwrite each other's cached
        # datasets; it only matters that they don't simultaneously try
        # to write the _same_ file, corrupting it.

logging.info("Cached processed data to disk")

# optionally, save the processed data on the Dataset object
# to avoid a roundtrip from disk in `__init__` (see above)
self.mydata = data

def get(self, idx: int) -> AtomicData:
"""Return the data frame with a given index as an `AtomicData` object."""
build_an_AtomicData_here = None
return build_an_AtomicData_here
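        # A minimal, hypothetical sketch of what building a frame might look like,
        # assuming `self.mydata` holds one dict of numpy arrays per frame and that
        # `custom_option1` is the neighbor-list cutoff (both are assumptions of this
        # example, not requirements of NequIP):
        #
        #     frame = self.mydata[idx]
        #     return AtomicData.from_points(
        #         pos=frame["positions"],  # (N, 3) Cartesian coordinates
        #         r_max=self.custom_option1,  # cutoff radius for the neighbor list
        #         atomic_numbers=torch.as_tensor(frame["atomic_numbers"]),  # (N,) integers
        #     )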

def statistics(
self,
fields: List[Union[str, Callable]],
modes: List[str],
stride: int = 1,
unbiased: bool = True,
kwargs: Optional[Dict[str, dict]] = {},
) -> List[tuple]:
"""Optional method to compute statistics over an entire dataset.
This must correctly handle `self._indices` for subsets!!!
If not provided, options like `avg_num_neighbors: auto`, `per_species_rescale_scales: dataset_*`,
and others that compute dataset statistics will not work. This only needs to support the statistics
modes that are necessary for what you need to run (i.e. if you do not use `dataset_per_species_*`
statistics, you do not need to implement them).
See `AtomicInMemoryDataset` for full documentation and example implementation.
"""
raise NotImplementedError
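
For orientation, a minimal, hypothetical usage sketch of the example class above; the import path, root directory, option values, and chemical species are placeholders, not part of this commit:

    from custom_dataset import ExampleCustomDataset  # i.e. the example file above
    from nequip.data.transforms import TypeMapper

    dataset = ExampleCustomDataset(
        root="/tmp/my-dataset-cache",  # processed data is cached under this root
        custom_option1=4.0,
        type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "O": 1}),
    )
    print(len(dataset))  # number of frames, via `len()`
    frame = dataset[0]   # an `AtomicData` from `get()`, with the `TypeMapper` applied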
62 changes: 31 additions & 31 deletions nequip/data/dataset.py
@@ -67,6 +67,37 @@ def type_mapper(self) -> Optional[TypeMapper]:
# self.transform is always a TypeMapper
return self.transform

def _get_parameters(self) -> Dict[str, Any]:
"""Get a dict of the parameters used to build this dataset."""
pnames = list(inspect.signature(self.__init__).parameters)
IGNORE_KEYS = {
# the type mapper is applied after saving, not before, so doesn't matter to cache validity
"type_mapper"
}
params = {
k: getattr(self, k)
for k in pnames
if k not in IGNORE_KEYS and hasattr(self, k)
}
# Add other relevant metadata:
params["dtype"] = str(torch.get_default_dtype())
params["nequip_version"] = nequip.__version__
return params

@property
def processed_dir(self) -> str:
# We want the file name to change when the parameters change
# So, first we get all parameters:
params = self._get_parameters()
# Make some kind of string of them:
# we don't care about this possibly changing between python versions,
# since a change in python version almost certainly means a change in
# versions of other things too, and is a good reason to recompute
buffer = yaml.dump(params).encode("ascii")
# And hash it:
param_hash = hashlib.sha1(buffer).hexdigest()
return f"{self.root}/processed_dataset_{param_hash}"


class AtomicInMemoryDataset(AtomicDataset):
r"""Base class for all datasets that fit in memory.
@@ -152,37 +183,6 @@ def len(self):
def raw_file_names(self):
raise NotImplementedError()

def _get_parameters(self) -> Dict[str, Any]:
"""Get a dict of the parameters used to build this dataset."""
pnames = list(inspect.signature(self.__init__).parameters)
IGNORE_KEYS = {
# the type mapper is applied after saving, not before, so doesn't matter to cache validity
"type_mapper"
}
params = {
k: getattr(self, k)
for k in pnames
if k not in IGNORE_KEYS and hasattr(self, k)
}
# Add other relevant metadata:
params["dtype"] = str(torch.get_default_dtype())
params["nequip_version"] = nequip.__version__
return params

@property
def processed_dir(self) -> str:
# We want the file name to change when the parameters change
# So, first we get all parameters:
params = self._get_parameters()
# Make some kind of string of them:
# we don't care about this possibly changing between python versions,
# since a change in python version almost certainly means a change in
# versions of other things too, and is a good reason to recompute
buffer = yaml.dump(params).encode("ascii")
# And hash it:
param_hash = hashlib.sha1(buffer).hexdigest()
return f"{self.root}/processed_dataset_{param_hash}"

@property
def processed_file_names(self) -> List[str]:
return ["data.pth", "params.yaml"]
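
The practical effect of moving `_get_parameters()` and `processed_dir` from `AtomicInMemoryDataset` up to `AtomicDataset` is that any subclass, including the example above, gets a per-configuration cache directory for free. A hedged sketch of the behavior, with placeholder paths and values:

    # Datasets that differ in any constructor option hash to different cache directories,
    # so changing an option triggers re-processing instead of silently reusing stale data.
    ds_a = ExampleCustomDataset(root="/tmp/cache", custom_option1=4.0)
    ds_b = ExampleCustomDataset(root="/tmp/cache", custom_option1=5.0)
    assert ds_a.processed_dir != ds_b.processed_dir
    # Each looks like: /tmp/cache/processed_dataset_<sha1 of the yaml.dump of the parameters>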
