Skip to content

Commit

Permalink
Jl/upgrade black (#645)
Browse files Browse the repository at this point in the history
* upgrade black

* formatting according to black
  • Loading branch information
jnsLs authored Jul 19, 2024
1 parent 4cff703 commit 9b61c9e
Show file tree
Hide file tree
Showing 40 changed files with 219 additions and 168 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
repos:
- repo: https://github.com/python/black
rev: 22.3.0
rev: 24.4.2
hooks:
- id: black
2 changes: 1 addition & 1 deletion src/schnetpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from schnetpack import md


__version__ = '2.0.4'
__version__ = "2.0.4"
13 changes: 9 additions & 4 deletions src/schnetpack/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def train(config: DictConfig):
else:
# choose seed randomly
with open_dict(config):
config.seed = random.randint(0, 2 ** 32 - 1)
config.seed = random.randint(0, 2**32 - 1)
log.info(f"Seed randomly with <{config.seed}>")
seed_everything(seed=config.seed, workers=True)

Expand All @@ -112,7 +112,11 @@ def train(config: DictConfig):
log.info(f"Instantiating datamodule <{config.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(
config.data,
train_sampler_cls=str2class(config.data.train_sampler_cls) if config.data.train_sampler_cls else None,
train_sampler_cls=(
str2class(config.data.train_sampler_cls)
if config.data.train_sampler_cls
else None
),
)

# Init model
Expand Down Expand Up @@ -208,13 +212,14 @@ def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
results = {k: v.detach().cpu() for k, v in results.items()}
return results


log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
config.trainer,
callbacks=[
PredictionWriter(
output_dir=config.outputdir, write_interval=config.write_interval, write_idx=config.write_idx_m
output_dir=config.outputdir,
write_interval=config.write_interval,
write_idx=config.write_idx_m,
)
],
default_root_dir=".",
Expand Down
1 change: 1 addition & 0 deletions src/schnetpack/data/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
The atomic simulation environment -- a Python library for working with atoms.
Journal of Physics: Condensed Matter, 9, 27. 2017.
"""

import logging
import os
from abc import ABC, abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def train_dataloader(self) -> AtomsLoader:
train_batch_sampler = self._setup_sampler(
sampler_cls=self.train_sampler_cls,
sampler_args=self.train_sampler_args,
dataset=self._train_dataset
dataset=self._train_dataset,
)

self._train_dataloader = AtomsLoader(
Expand Down
4 changes: 2 additions & 2 deletions src/schnetpack/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
num_workers: int = 0,
collate_fn: _collate_fn_t = _atoms_collate_fn,
pin_memory: bool = False,
**kwargs
**kwargs,
):
super(AtomsLoader, self).__init__(
dataset=dataset,
Expand All @@ -82,5 +82,5 @@ def __init__(
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
**kwargs
**kwargs,
)
21 changes: 13 additions & 8 deletions src/schnetpack/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class NumberOfAtomsCriterion:
"""
A callable class that returns the number of atoms for each sample in the dataset.
"""

def __call__(self, dataset):
n_atoms = []
for spl_idx in range(len(dataset)):
Expand All @@ -31,6 +32,7 @@ class PropertyCriterion:
A callable class that returns the specified property for each sample in the dataset.
Property must be a scalar value.
"""

def __init__(self, property_key: str = properties.energy):
self.property_key = property_key

Expand All @@ -48,14 +50,15 @@ class StratifiedSampler(WeightedRandomSampler):
Note: Make sure that num_bins is chosen sufficiently small to avoid too many empty bins.
"""

def __init__(
self,
data_source: BaseAtomsData,
partition_criterion: Callable[[BaseAtomsData], List],
num_samples: int,
num_bins: int = 10,
replacement: bool = True,
verbose: bool = True,
self,
data_source: BaseAtomsData,
partition_criterion: Callable[[BaseAtomsData], List],
num_samples: int,
num_bins: int = 10,
replacement: bool = True,
verbose: bool = True,
) -> None:
"""
Args:
Expand All @@ -72,7 +75,9 @@ def __init__(
self.verbose = verbose

weights = self.calculate_weights(partition_criterion)
super().__init__(weights=weights, num_samples=num_samples, replacement=replacement)
super().__init__(
weights=weights, num_samples=num_samples, replacement=replacement
)

def calculate_weights(self, partition_criterion):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/schnetpack/datasets/ani1.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
num_test_workers: Optional[int] = None,
property_units: Optional[Dict[str, str]] = None,
distance_unit: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(
num_test_workers=num_test_workers,
property_units=property_units,
distance_unit=distance_unit,
**kwargs
**kwargs,
)

def prepare_data(self):
Expand Down
13 changes: 9 additions & 4 deletions src/schnetpack/datasets/iso17.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
num_test_workers: Optional[int] = None,
property_units: Optional[Dict[str, str]] = None,
distance_unit: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
num_test_workers=num_test_workers,
property_units=property_units,
distance_unit=distance_unit,
**kwargs
**kwargs,
)

def prepare_data(self):
Expand Down Expand Up @@ -148,12 +148,17 @@ def _download_data(self):
with connect(dbpath) as conn:
with connect(tmp_dbpath) as tmp_conn:
tmp_conn.metadata = {
"_property_unit_dict": {ISO17.energy: "eV", ISO17.forces: "eV/Ang"},
"_property_unit_dict": {
ISO17.energy: "eV",
ISO17.forces: "eV/Ang",
},
"_distance_unit": "Ang",
"atomrefs": {},
}
# add energy to data dict in db
for idx in tqdm(range(len(conn)), f"parsing database file {dbpath}"):
for idx in tqdm(
range(len(conn)), f"parsing database file {dbpath}"
):
atmsrw = conn.get(idx + 1)
data = atmsrw.data
data[ISO17.forces] = np.array(data[ISO17.forces])
Expand Down
20 changes: 11 additions & 9 deletions src/schnetpack/datasets/materials_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
distance_unit: Optional[str] = None,
apikey: Optional[str] = None,
timestamp: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
num_test_workers=num_test_workers,
property_units=property_units,
distance_unit=distance_unit,
**kwargs
**kwargs,
)
if len(apikey) != 16:
raise AtomsDataModuleError(
Expand Down Expand Up @@ -197,13 +197,15 @@ def _download_data(self, dataset: BaseAtomsData):
)
properties_list.append(
{
MaterialsProject.EPerAtom: np.array([q["energy_per_atom"]]),
MaterialsProject.EformationPerAtom: np.array([q[
"formation_energy_per_atom"
]]),
MaterialsProject.TotalMagnetization: np.array([q[
"total_magnetization"
]]),
MaterialsProject.EPerAtom: np.array(
[q["energy_per_atom"]]
),
MaterialsProject.EformationPerAtom: np.array(
[q["formation_energy_per_atom"]]
),
MaterialsProject.TotalMagnetization: np.array(
[q["total_magnetization"]]
),
MaterialsProject.BandGap: np.array([q["band_gap"]]),
}
)
Expand Down
4 changes: 3 additions & 1 deletion src/schnetpack/datasets/md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def _download_data(
for positions, energies, forces in zip(data["R"], data["E"], data["F"]):
ats = Atoms(positions=positions, numbers=numbers)
properties = {
self.energy: energies if type(energies) is np.ndarray else np.array([energies]),
self.energy: (
energies if type(energies) is np.ndarray else np.array([energies])
),
self.forces: forces,
structure.Z: ats.numbers,
structure.R: ats.positions,
Expand Down
4 changes: 2 additions & 2 deletions src/schnetpack/datasets/omdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
property_units: Optional[Dict[str, str]] = None,
distance_unit: Optional[str] = None,
raw_path: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
num_test_workers=num_test_workers,
property_units=property_units,
distance_unit=distance_unit,
**kwargs
**kwargs,
)
self.raw_path = raw_path

Expand Down
4 changes: 2 additions & 2 deletions src/schnetpack/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
property_units: Optional[Dict[str, str]] = None,
distance_unit: Optional[str] = None,
data_workdir: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(
property_units=property_units,
distance_unit=distance_unit,
data_workdir=data_workdir,
**kwargs
**kwargs,
)

self.remove_uncharacterized = remove_uncharacterized
Expand Down
4 changes: 3 additions & 1 deletion src/schnetpack/datasets/rmd17.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def __init__(
"""

if split_id is not None:
splitting = SubsamplePartitions(split_partition_sources=["known", "known", "test"], split_id=split_id)
splitting = SubsamplePartitions(
split_partition_sources=["known", "known", "test"], split_id=split_id
)
else:
splitting = RandomSplit()

Expand Down
43 changes: 21 additions & 22 deletions src/schnetpack/datasets/tmqm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class TMQM(AtomsDataModule):
"""tmQM database of Ballcells 2020 of inorganic CSD structures.
References:
Expand All @@ -41,7 +41,7 @@ class TMQM(AtomsDataModule):
# dipole moment, and natural charge of the metal center; GFN2-xTB polarizabilities are also provided.

# these strings match the names in the header of the csv file
csd_code = "CSD_code" #should go into key-value pair
csd_code = "CSD_code" # should go into key-value pair
energy = "Electronic_E"
dispersion = "Dispersion_E"
homo = "HOMO_Energy"
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
property_units: Optional[Dict[str, str]] = None,
distance_unit: Optional[str] = None,
data_workdir: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Expand Down Expand Up @@ -121,10 +121,9 @@ def __init__(
property_units=property_units,
distance_unit=distance_unit,
data_workdir=data_workdir,
**kwargs
**kwargs,
)


def prepare_data(self):
if not os.path.exists(self.datapath):
property_unit_dict = {
Expand Down Expand Up @@ -152,12 +151,12 @@ def prepare_data(self):
else:
dataset = load_dataset(self.datapath, self.format)

def _download_data(
self, tmpdir, dataset: BaseAtomsData
):
def _download_data(self, tmpdir, dataset: BaseAtomsData):
tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz")
url = ["https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz",
"https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X2.xyz.gz"]
url = [
"https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz",
"https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X2.xyz.gz",
]

url_y = "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_y.csv"

Expand All @@ -168,24 +167,23 @@ def _download_data(

for u in url:
request.urlretrieve(u, tar_path)
with gzip.open(tar_path, 'rb') as f_in:
with open(tmp_xyz_file, 'wb') as f_out:
with gzip.open(tar_path, "rb") as f_in:
with open(tmp_xyz_file, "wb") as f_out:
lines = f_in.readlines()
# remove empty lines
lines = [line for line in lines if line.strip()]
f_out.writelines(lines)

atomslist.extend(read(tmp_xyz_file, index=":"))

atomslist.extend(read(tmp_xyz_file, index=":"))

# download proeprties in tmQM_y.csv
request.urlretrieve(url_y, tmp_properties_file)

# CSV format
#CSD_code;Electronic_E;Dispersion_E;Dipole_M;Metal_q;HL_Gap;HOMO_Energy;LUMO_Energy;Polarizability
#WIXKOE;-2045.524942;-0.239239;4.233300;2.109340;0.131080;-0.162040;-0.030960;598.457913
#DUCVIG;-2430.690317;-0.082134;11.754400;0.759940;0.124930;-0.243580;-0.118650;277.750698
#KINJOG;-3467.923206;-0.137954;8.301700;1.766500;0.140140;-0.236460;-0.096320;393.442545
# CSD_code;Electronic_E;Dispersion_E;Dipole_M;Metal_q;HL_Gap;HOMO_Energy;LUMO_Energy;Polarizability
# WIXKOE;-2045.524942;-0.239239;4.233300;2.109340;0.131080;-0.162040;-0.030960;598.457913
# DUCVIG;-2430.690317;-0.082134;11.754400;0.759940;0.124930;-0.243580;-0.118650;277.750698
# KINJOG;-3467.923206;-0.137954;8.301700;1.766500;0.140140;-0.236460;-0.096320;393.442545

# read csv
prop_list = []
Expand All @@ -194,13 +192,14 @@ def _download_data(
with open(tmp_properties_file, "r") as file:
lines = file.readlines()
keys = lines[0].strip("\n").split(";")

for l in lines[1:]:
properties = l.split(";")
prop_dict = {k:np.array([float(v)]) for k, v in zip(keys[1:], properties[1:])}
key_value_pairs = {k:v for k, v in zip(keys[0], properties[0])}
prop_dict = {
k: np.array([float(v)]) for k, v in zip(keys[1:], properties[1:])
}
key_value_pairs = {k: v for k, v in zip(keys[0], properties[0])}
prop_list.append(prop_dict)
key_value_pairs_list.append(key_value_pairs)


dataset.add_systems(property_list=prop_list, atoms_list=atomslist)
Loading

0 comments on commit 9b61c9e

Please sign in to comment.