diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e21f83820..b63272e1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ repos: - repo: https://github.com/python/black - rev: 22.3.0 + rev: 24.4.2 hooks: - id: black diff --git a/src/schnetpack/__init__.py b/src/schnetpack/__init__.py index 725f476d4..a3d808fce 100644 --- a/src/schnetpack/__init__.py +++ b/src/schnetpack/__init__.py @@ -17,4 +17,4 @@ from schnetpack import md -__version__ = '2.0.4' +__version__ = "2.0.4" diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 645610b98..6f986f7ea 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -98,7 +98,7 @@ def train(config: DictConfig): else: # choose seed randomly with open_dict(config): - config.seed = random.randint(0, 2 ** 32 - 1) + config.seed = random.randint(0, 2**32 - 1) log.info(f"Seed randomly with <{config.seed}>") seed_everything(seed=config.seed, workers=True) @@ -112,7 +112,11 @@ def train(config: DictConfig): log.info(f"Instantiating datamodule <{config.data._target_}>") datamodule: LightningDataModule = hydra.utils.instantiate( config.data, - train_sampler_cls=str2class(config.data.train_sampler_cls) if config.data.train_sampler_cls else None, + train_sampler_cls=( + str2class(config.data.train_sampler_cls) + if config.data.train_sampler_cls + else None + ), ) # Init model @@ -208,13 +212,14 @@ def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0): results = {k: v.detach().cpu() for k, v in results.items()} return results - log.info(f"Instantiating trainer <{config.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate( config.trainer, callbacks=[ PredictionWriter( - output_dir=config.outputdir, write_interval=config.write_interval, write_idx=config.write_idx_m + output_dir=config.outputdir, + write_interval=config.write_interval, + write_idx=config.write_idx_m, ) ], default_root_dir=".", diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py index 173a4cc08..c4c6975be 100644 --- a/src/schnetpack/data/atoms.py +++ b/src/schnetpack/data/atoms.py @@ -10,6 +10,7 @@ The atomic simulation environment -- a Python library for working with atoms. Journal of Physics: Condensed Matter, 9, 27. 2017. """ + import logging import os from abc import ABC, abstractmethod diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py index d70f6ce2b..4484dc5e8 100644 --- a/src/schnetpack/data/datamodule.py +++ b/src/schnetpack/data/datamodule.py @@ -377,7 +377,7 @@ def train_dataloader(self) -> AtomsLoader: train_batch_sampler = self._setup_sampler( sampler_cls=self.train_sampler_cls, sampler_args=self.train_sampler_args, - dataset=self._train_dataset + dataset=self._train_dataset, ) self._train_dataloader = AtomsLoader( diff --git a/src/schnetpack/data/loader.py b/src/schnetpack/data/loader.py index 84de7fac4..ffded9f5f 100644 --- a/src/schnetpack/data/loader.py +++ b/src/schnetpack/data/loader.py @@ -71,7 +71,7 @@ def __init__( num_workers: int = 0, collate_fn: _collate_fn_t = _atoms_collate_fn, pin_memory: bool = False, - **kwargs + **kwargs, ): super(AtomsLoader, self).__init__( dataset=dataset, @@ -82,5 +82,5 @@ def __init__( num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, - **kwargs + **kwargs, ) diff --git a/src/schnetpack/data/sampler.py b/src/schnetpack/data/sampler.py index 00994f345..0e353ef88 100644 --- a/src/schnetpack/data/sampler.py +++ b/src/schnetpack/data/sampler.py @@ -18,6 +18,7 @@ class NumberOfAtomsCriterion: """ A callable class that returns the number of atoms for each sample in the dataset. """ + def __call__(self, dataset): n_atoms = [] for spl_idx in range(len(dataset)): @@ -31,6 +32,7 @@ class PropertyCriterion: A callable class that returns the specified property for each sample in the dataset. Property must be a scalar value. """ + def __init__(self, property_key: str = properties.energy): self.property_key = property_key @@ -48,14 +50,15 @@ class StratifiedSampler(WeightedRandomSampler): Note: Make sure that num_bins is chosen sufficiently small to avoid too many empty bins. """ + def __init__( - self, - data_source: BaseAtomsData, - partition_criterion: Callable[[BaseAtomsData], List], - num_samples: int, - num_bins: int = 10, - replacement: bool = True, - verbose: bool = True, + self, + data_source: BaseAtomsData, + partition_criterion: Callable[[BaseAtomsData], List], + num_samples: int, + num_bins: int = 10, + replacement: bool = True, + verbose: bool = True, ) -> None: """ Args: @@ -72,7 +75,9 @@ def __init__( self.verbose = verbose weights = self.calculate_weights(partition_criterion) - super().__init__(weights=weights, num_samples=num_samples, replacement=replacement) + super().__init__( + weights=weights, num_samples=num_samples, replacement=replacement + ) def calculate_weights(self, partition_criterion): """ diff --git a/src/schnetpack/datasets/ani1.py b/src/schnetpack/datasets/ani1.py index 37ce2b007..48fb0f062 100644 --- a/src/schnetpack/datasets/ani1.py +++ b/src/schnetpack/datasets/ani1.py @@ -61,7 +61,7 @@ def __init__( num_test_workers: Optional[int] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - **kwargs + **kwargs, ): """ @@ -112,7 +112,7 @@ def __init__( num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, - **kwargs + **kwargs, ) def prepare_data(self): diff --git a/src/schnetpack/datasets/iso17.py b/src/schnetpack/datasets/iso17.py index 76d5aa094..cfc8bb391 100644 --- a/src/schnetpack/datasets/iso17.py +++ b/src/schnetpack/datasets/iso17.py @@ -62,7 +62,7 @@ def __init__( num_test_workers: Optional[int] = None, property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, - **kwargs + **kwargs, ): """ Args: @@ -113,7 +113,7 @@ def __init__( num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, - **kwargs + **kwargs, ) def prepare_data(self): @@ -148,12 +148,17 @@ def _download_data(self): with connect(dbpath) as conn: with connect(tmp_dbpath) as tmp_conn: tmp_conn.metadata = { - "_property_unit_dict": {ISO17.energy: "eV", ISO17.forces: "eV/Ang"}, + "_property_unit_dict": { + ISO17.energy: "eV", + ISO17.forces: "eV/Ang", + }, "_distance_unit": "Ang", "atomrefs": {}, } # add energy to data dict in db - for idx in tqdm(range(len(conn)), f"parsing database file {dbpath}"): + for idx in tqdm( + range(len(conn)), f"parsing database file {dbpath}" + ): atmsrw = conn.get(idx + 1) data = atmsrw.data data[ISO17.forces] = np.array(data[ISO17.forces]) diff --git a/src/schnetpack/datasets/materials_project.py b/src/schnetpack/datasets/materials_project.py index a7c203df6..fa9cd0622 100644 --- a/src/schnetpack/datasets/materials_project.py +++ b/src/schnetpack/datasets/materials_project.py @@ -54,7 +54,7 @@ def __init__( distance_unit: Optional[str] = None, apikey: Optional[str] = None, timestamp: Optional[str] = None, - **kwargs + **kwargs, ): """ @@ -101,7 +101,7 @@ def __init__( num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, - **kwargs + **kwargs, ) if len(apikey) != 16: raise AtomsDataModuleError( @@ -197,13 +197,15 @@ def _download_data(self, dataset: BaseAtomsData): ) properties_list.append( { - MaterialsProject.EPerAtom: np.array([q["energy_per_atom"]]), - MaterialsProject.EformationPerAtom: np.array([q[ - "formation_energy_per_atom" - ]]), - MaterialsProject.TotalMagnetization: np.array([q[ - "total_magnetization" - ]]), + MaterialsProject.EPerAtom: np.array( + [q["energy_per_atom"]] + ), + MaterialsProject.EformationPerAtom: np.array( + [q["formation_energy_per_atom"]] + ), + MaterialsProject.TotalMagnetization: np.array( + [q["total_magnetization"]] + ), MaterialsProject.BandGap: np.array([q["band_gap"]]), } ) diff --git a/src/schnetpack/datasets/md17.py b/src/schnetpack/datasets/md17.py index 46249ee51..c79545aef 100644 --- a/src/schnetpack/datasets/md17.py +++ b/src/schnetpack/datasets/md17.py @@ -170,7 +170,9 @@ def _download_data( for positions, energies, forces in zip(data["R"], data["E"], data["F"]): ats = Atoms(positions=positions, numbers=numbers) properties = { - self.energy: energies if type(energies) is np.ndarray else np.array([energies]), + self.energy: ( + energies if type(energies) is np.ndarray else np.array([energies]) + ), self.forces: forces, structure.Z: ats.numbers, structure.R: ats.positions, diff --git a/src/schnetpack/datasets/omdb.py b/src/schnetpack/datasets/omdb.py index ee025393a..413ef5ba0 100644 --- a/src/schnetpack/datasets/omdb.py +++ b/src/schnetpack/datasets/omdb.py @@ -51,7 +51,7 @@ def __init__( property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, raw_path: Optional[str] = None, - **kwargs + **kwargs, ): """ Args: @@ -96,7 +96,7 @@ def __init__( num_test_workers=num_test_workers, property_units=property_units, distance_unit=distance_unit, - **kwargs + **kwargs, ) self.raw_path = raw_path diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index a36c31caf..b2c16fe60 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -74,7 +74,7 @@ def __init__( property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, data_workdir: Optional[str] = None, - **kwargs + **kwargs, ): """ @@ -122,7 +122,7 @@ def __init__( property_units=property_units, distance_unit=distance_unit, data_workdir=data_workdir, - **kwargs + **kwargs, ) self.remove_uncharacterized = remove_uncharacterized diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 0dae33b74..01feac44f 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -118,7 +118,9 @@ def __init__( """ if split_id is not None: - splitting = SubsamplePartitions(split_partition_sources=["known", "known", "test"], split_id=split_id) + splitting = SubsamplePartitions( + split_partition_sources=["known", "known", "test"], split_id=split_id + ) else: splitting = RandomSplit() diff --git a/src/schnetpack/datasets/tmqm.py b/src/schnetpack/datasets/tmqm.py index 07d3a7eb6..17c856ec7 100644 --- a/src/schnetpack/datasets/tmqm.py +++ b/src/schnetpack/datasets/tmqm.py @@ -26,7 +26,7 @@ class TMQM(AtomsDataModule): """tmQM database of Ballcells 2020 of inorganic CSD structures. - + References: @@ -41,7 +41,7 @@ class TMQM(AtomsDataModule): # dipole moment, and natural charge of the metal center; GFN2-xTB polarizabilities are also provided. # these strings match the names in the header of the csv file - csd_code = "CSD_code" #should go into key-value pair + csd_code = "CSD_code" # should go into key-value pair energy = "Electronic_E" dispersion = "Dispersion_E" homo = "HOMO_Energy" @@ -73,7 +73,7 @@ def __init__( property_units: Optional[Dict[str, str]] = None, distance_unit: Optional[str] = None, data_workdir: Optional[str] = None, - **kwargs + **kwargs, ): """ @@ -121,10 +121,9 @@ def __init__( property_units=property_units, distance_unit=distance_unit, data_workdir=data_workdir, - **kwargs + **kwargs, ) - def prepare_data(self): if not os.path.exists(self.datapath): property_unit_dict = { @@ -152,12 +151,12 @@ def prepare_data(self): else: dataset = load_dataset(self.datapath, self.format) - def _download_data( - self, tmpdir, dataset: BaseAtomsData - ): + def _download_data(self, tmpdir, dataset: BaseAtomsData): tar_path = os.path.join(tmpdir, "tmQM_X1.xyz.gz") - url = ["https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz", - "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X2.xyz.gz"] + url = [ + "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X1.xyz.gz", + "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_X2.xyz.gz", + ] url_y = "https://github.com/bbskjelstad/tmqm/raw/master/data/tmQM_y.csv" @@ -168,24 +167,23 @@ def _download_data( for u in url: request.urlretrieve(u, tar_path) - with gzip.open(tar_path, 'rb') as f_in: - with open(tmp_xyz_file, 'wb') as f_out: + with gzip.open(tar_path, "rb") as f_in: + with open(tmp_xyz_file, "wb") as f_out: lines = f_in.readlines() # remove empty lines lines = [line for line in lines if line.strip()] f_out.writelines(lines) - - atomslist.extend(read(tmp_xyz_file, index=":")) + atomslist.extend(read(tmp_xyz_file, index=":")) # download proeprties in tmQM_y.csv request.urlretrieve(url_y, tmp_properties_file) # CSV format - #CSD_code;Electronic_E;Dispersion_E;Dipole_M;Metal_q;HL_Gap;HOMO_Energy;LUMO_Energy;Polarizability - #WIXKOE;-2045.524942;-0.239239;4.233300;2.109340;0.131080;-0.162040;-0.030960;598.457913 - #DUCVIG;-2430.690317;-0.082134;11.754400;0.759940;0.124930;-0.243580;-0.118650;277.750698 - #KINJOG;-3467.923206;-0.137954;8.301700;1.766500;0.140140;-0.236460;-0.096320;393.442545 + # CSD_code;Electronic_E;Dispersion_E;Dipole_M;Metal_q;HL_Gap;HOMO_Energy;LUMO_Energy;Polarizability + # WIXKOE;-2045.524942;-0.239239;4.233300;2.109340;0.131080;-0.162040;-0.030960;598.457913 + # DUCVIG;-2430.690317;-0.082134;11.754400;0.759940;0.124930;-0.243580;-0.118650;277.750698 + # KINJOG;-3467.923206;-0.137954;8.301700;1.766500;0.140140;-0.236460;-0.096320;393.442545 # read csv prop_list = [] @@ -194,13 +192,14 @@ def _download_data( with open(tmp_properties_file, "r") as file: lines = file.readlines() keys = lines[0].strip("\n").split(";") - + for l in lines[1:]: properties = l.split(";") - prop_dict = {k:np.array([float(v)]) for k, v in zip(keys[1:], properties[1:])} - key_value_pairs = {k:v for k, v in zip(keys[0], properties[0])} + prop_dict = { + k: np.array([float(v)]) for k, v in zip(keys[1:], properties[1:]) + } + key_value_pairs = {k: v for k, v in zip(keys[0], properties[0])} prop_list.append(prop_dict) key_value_pairs_list.append(key_value_pairs) - dataset.add_systems(property_list=prop_list, atoms_list=atomslist) diff --git a/src/schnetpack/interfaces/ase_interface.py b/src/schnetpack/interfaces/ase_interface.py index 4c8dc0a74..d5fdfe79a 100644 --- a/src/schnetpack/interfaces/ase_interface.py +++ b/src/schnetpack/interfaces/ase_interface.py @@ -193,7 +193,6 @@ def __init__( additional_inputs: Dict[str, torch.Tensor] = None, **kwargs, ): - """ Args: model_file (str): path to trained model diff --git a/src/schnetpack/interfaces/batchwise_optimization.py b/src/schnetpack/interfaces/batchwise_optimization.py index 1bc8842f3..c37193550 100644 --- a/src/schnetpack/interfaces/batchwise_optimization.py +++ b/src/schnetpack/interfaces/batchwise_optimization.py @@ -21,7 +21,12 @@ from schnetpack.interfaces.ase_interface import AtomsConverter -__all__ = ["ASEBatchwiseLBFGS", "BatchwiseCalculator", "BatchwiseEnsembleCalculator", "NNEnsemble"] +__all__ = [ + "ASEBatchwiseLBFGS", + "BatchwiseCalculator", + "BatchwiseEnsembleCalculator", + "NNEnsemble", +] class AtomsConverterError(Exception): @@ -141,7 +146,9 @@ def __init__( self.force_key: self.energy_conversion / self.position_conversion, } if self.stress_key is not None: - self.property_units[self.stress_key] = self.energy_conversion / self.position_conversion ** 3 + self.property_units[self.stress_key] = ( + self.energy_conversion / self.position_conversion**3 + ) # load model from path if needed if type(model) == str: @@ -170,14 +177,18 @@ def _requires_calculation(self, property_keys: List[str], atoms: List[ase.Atoms] if atom != atom_ref: return True - def get_forces(self, atoms: List[ase.Atoms], fixed_atoms_mask: Optional[List[int]] = None) -> np.array: + def get_forces( + self, atoms: List[ase.Atoms], fixed_atoms_mask: Optional[List[int]] = None + ) -> np.array: """ atoms: fixed_atoms_mask: list of indices corresponding to atoms with positions fixed in space. """ - if self._requires_calculation(property_keys=[self.energy_key, self.force_key], atoms=atoms): + if self._requires_calculation( + property_keys=[self.energy_key, self.force_key], atoms=atoms + ): self.calculate(atoms) f = self.results[self.force_key] if fixed_atoms_mask is not None: @@ -217,6 +228,7 @@ class BatchwiseEnsembleCalculator(BatchwiseCalculator): """ Calculator for ensemble of neural network models for batchwise optimization. """ + # TODO: inherit from SpkEnsembleCalculator def __init__( self, @@ -278,16 +290,14 @@ def __init__( def _load_model(self, model: str) -> nn.ModuleList: # get model paths model_names = os.listdir(model) - model_paths = [ - os.path.join(model, model_name) for model_name in model_names - ] + model_paths = [os.path.join(model, model_name) for model_name in model_names] # create module list models = torch.nn.ModuleList() for m_path in model_paths: - m = torch.load( - os.path.join(m_path, "best_model"), map_location="cpu" - ).to(torch.float64) + m = torch.load(os.path.join(m_path, "best_model"), map_location="cpu").to( + torch.float64 + ) models.append(m) return models @@ -299,9 +309,7 @@ def _initialize_model(self, model: nn.ModuleList) -> None: m.output_modules.insert(1, auxiliary_output_module) # initialize ensemble - ensemble = NNEnsemble( - models=model, properties=list(self.property_units.keys()) - ) + ensemble = NNEnsemble(models=model, properties=list(self.property_units.keys())) self.model = ensemble.eval().to(device=self.device, dtype=self.dtype) def calculate(self, atoms: List[ase.Atoms]) -> None: @@ -314,8 +322,7 @@ def calculate(self, atoms: List[ase.Atoms]) -> None: for prop in property_keys: if prop in model_results: results["{}_uncertainty".format(prop)] = ( - stds[prop].detach().cpu().numpy() - * self.property_units[prop] + stds[prop].detach().cpu().numpy() * self.property_units[prop] ) # store model results in calculator @@ -348,7 +355,7 @@ def __init__( append_trajectory: bool = False, master: Optional[bool] = None, log_every_step: bool = False, - fixed_atoms_mask: Optional[List[int]]=None, + fixed_atoms_mask: Optional[List[int]] = None, ): """Structure dynamics object. @@ -626,7 +633,6 @@ def __init__( fixed_atoms_mask: Optional[List[int]] = None, verbose: bool = False, ): - """Parameters: calculator: diff --git a/src/schnetpack/md/__init__.py b/src/schnetpack/md/__init__.py index 38b36d885..35d1391f8 100644 --- a/src/schnetpack/md/__init__.py +++ b/src/schnetpack/md/__init__.py @@ -2,6 +2,7 @@ This module contains all functionality for performing various molecular dynamics simulations using SchNetPack. """ + from .system import * from .initial_conditions import * from .simulator import * diff --git a/src/schnetpack/md/data/spectra.py b/src/schnetpack/md/data/spectra.py index 980e6a8ec..9c66ef1e6 100644 --- a/src/schnetpack/md/data/spectra.py +++ b/src/schnetpack/md/data/spectra.py @@ -9,6 +9,7 @@ Computing vibrational spectra from ab initio molecular dynamics. Phys. Chem. Chem. Phys., 15 (18), 6608--6622. 2013. """ + import numpy as np from ase import units as ase_units from schnetpack.md.data import HDF5Loader diff --git a/src/schnetpack/md/initial_conditions.py b/src/schnetpack/md/initial_conditions.py index 5439e7128..8e4122377 100644 --- a/src/schnetpack/md/initial_conditions.py +++ b/src/schnetpack/md/initial_conditions.py @@ -2,6 +2,7 @@ Module for setting up the initial conditions of the molecules in :obj:`schnetpack.md.System`. This entails sampling the momenta from random distributions corresponding to certain temperatures. """ + import torch from schnetpack.md import System from schnetpack import units as spk_units diff --git a/src/schnetpack/md/integrators.py b/src/schnetpack/md/integrators.py index 3bbe1f4f1..c47dd028d 100644 --- a/src/schnetpack/md/integrators.py +++ b/src/schnetpack/md/integrators.py @@ -5,6 +5,7 @@ integrator simulates multiple replicas of the system coupled by harmonic springs and recovers a certain extent of nuclear quantum effects (e.g. tunneling). """ + import torch import torch.nn as nn import numpy as np diff --git a/src/schnetpack/md/simulation_hooks/barostats.py b/src/schnetpack/md/simulation_hooks/barostats.py index c7668317f..3440418b3 100644 --- a/src/schnetpack/md/simulation_hooks/barostats.py +++ b/src/schnetpack/md/simulation_hooks/barostats.py @@ -2,6 +2,7 @@ This module contains various barostats for controlling the pressure of the system during molecular dynamics simulations. """ + from __future__ import annotations from typing import Optional, Tuple, TYPE_CHECKING @@ -357,16 +358,12 @@ def _init_thermostat_variables( # Get masses of innermost thermostat self.t_masses[..., 0] = ( - self.degrees_of_freedom_particles - * self.kb_temperature - / self.frequency**2 + self.degrees_of_freedom_particles * self.kb_temperature / self.frequency**2 ) # Get masses of cell self.t_masses_cell[..., 0] = ( - self.degrees_of_freedom_cell - * self.kb_temperature - / self.cell_frequency**2 + self.degrees_of_freedom_cell * self.kb_temperature / self.cell_frequency**2 ) # Set masses of remaining thermostats diff --git a/src/schnetpack/md/simulation_hooks/barostats_rpmd.py b/src/schnetpack/md/simulation_hooks/barostats_rpmd.py index e5d874b61..630d67689 100644 --- a/src/schnetpack/md/simulation_hooks/barostats_rpmd.py +++ b/src/schnetpack/md/simulation_hooks/barostats_rpmd.py @@ -2,6 +2,7 @@ This module contains barostats for controlling the pressure of the system during ring polymer molecular dynamics simulations. """ + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/src/schnetpack/md/simulation_hooks/callback_hooks.py b/src/schnetpack/md/simulation_hooks/callback_hooks.py index 9f913aa46..be665dabf 100644 --- a/src/schnetpack/md/simulation_hooks/callback_hooks.py +++ b/src/schnetpack/md/simulation_hooks/callback_hooks.py @@ -1,6 +1,7 @@ """ This module contains different hooks for monitoring the simulation and checkpointing. """ + from __future__ import annotations from typing import Union, List, Dict, Tuple, Any from typing import TYPE_CHECKING @@ -270,40 +271,40 @@ def update_buffer(self, buffer_position: int, simulator: Simulator): # Store energies start = 0 stop = simulator.system.n_molecules - self.buffer[ - buffer_position : buffer_position + 1, :, start:stop - ] = simulator.system.energy.view(simulator.system.n_replicas, -1).detach() + self.buffer[buffer_position : buffer_position + 1, :, start:stop] = ( + simulator.system.energy.view(simulator.system.n_replicas, -1).detach() + ) # Store positions start = stop stop += simulator.system.total_n_atoms * 3 - self.buffer[ - buffer_position : buffer_position + 1, :, start:stop - ] = simulator.system.positions.view(simulator.system.n_replicas, -1).detach() + self.buffer[buffer_position : buffer_position + 1, :, start:stop] = ( + simulator.system.positions.view(simulator.system.n_replicas, -1).detach() + ) if self.store_velocities: start = stop stop += simulator.system.total_n_atoms * 3 - self.buffer[ - buffer_position : buffer_position + 1, :, start:stop - ] = simulator.system.velocities.view( - simulator.system.n_replicas, -1 - ).detach() + self.buffer[buffer_position : buffer_position + 1, :, start:stop] = ( + simulator.system.velocities.view( + simulator.system.n_replicas, -1 + ).detach() + ) if self.cells: # Get cells start = stop stop += 9 * simulator.system.n_molecules - self.buffer[ - buffer_position : buffer_position + 1, :, start:stop - ] = simulator.system.cells.view(simulator.system.n_replicas, -1).detach() + self.buffer[buffer_position : buffer_position + 1, :, start:stop] = ( + simulator.system.cells.view(simulator.system.n_replicas, -1).detach() + ) # Get stress tensors start = stop stop += 9 * simulator.system.n_molecules - self.buffer[ - buffer_position : buffer_position + 1, :, start:stop - ] = simulator.system.stress.view(simulator.system.n_replicas, -1).detach() + self.buffer[buffer_position : buffer_position + 1, :, start:stop] = ( + simulator.system.stress.view(simulator.system.n_replicas, -1).detach() + ) class PropertyStream(DataStream): diff --git a/src/schnetpack/md/simulation_hooks/thermostats.py b/src/schnetpack/md/simulation_hooks/thermostats.py index 1b51c8606..5b42369be 100644 --- a/src/schnetpack/md/simulation_hooks/thermostats.py +++ b/src/schnetpack/md/simulation_hooks/thermostats.py @@ -2,6 +2,7 @@ This module contains various thermostats for regulating the temperature of the system during molecular dynamics simulations. """ + from __future__ import annotations import torch import numpy as np diff --git a/src/schnetpack/md/simulation_hooks/thermostats_rpmd.py b/src/schnetpack/md/simulation_hooks/thermostats_rpmd.py index 301dd482d..d17f7f166 100644 --- a/src/schnetpack/md/simulation_hooks/thermostats_rpmd.py +++ b/src/schnetpack/md/simulation_hooks/thermostats_rpmd.py @@ -2,6 +2,7 @@ This module contains pecialized thermostats for controlling temperature of the system during ring polymer molecular dynamics simulations. """ + from __future__ import annotations import torch diff --git a/src/schnetpack/md/simulator.py b/src/schnetpack/md/simulator.py index bddd5f499..1b1068893 100644 --- a/src/schnetpack/md/simulator.py +++ b/src/schnetpack/md/simulator.py @@ -4,6 +4,7 @@ integrators (:obj:`schnetpack.md.integrators`) and various simulation hooks (:obj:`schnetpack.md.simulation_hooks`) and performs the time integration. """ + import torch import torch.nn as nn from contextlib import nullcontext diff --git a/src/schnetpack/md/system.py b/src/schnetpack/md/system.py index e8b708e64..f051881cd 100644 --- a/src/schnetpack/md/system.py +++ b/src/schnetpack/md/system.py @@ -3,6 +3,7 @@ It includes functionality for loading molecules from files. All this functionality is encoded in the :obj:`schnetpack.md.System` class. """ + import torch import torch.nn as nn diff --git a/src/schnetpack/md/utils/__init__.py b/src/schnetpack/md/utils/__init__.py index fa8d38758..33f1539d9 100644 --- a/src/schnetpack/md/utils/__init__.py +++ b/src/schnetpack/md/utils/__init__.py @@ -60,9 +60,9 @@ def activate_model_stress( module.derivative_instructions["dEds"] = True module.basic_derivatives["dEds"] = schnetpack.properties.strain - module.map_properties[ + module.map_properties[schnetpack.properties.stress] = ( schnetpack.properties.stress - ] = schnetpack.properties.stress + ) # append stress label to output list and update required derivatives in the module module.model_outputs.append(stress_key) diff --git a/src/schnetpack/nn/activations.py b/src/schnetpack/nn/activations.py index be7cc45b0..507577f94 100644 --- a/src/schnetpack/nn/activations.py +++ b/src/schnetpack/nn/activations.py @@ -49,12 +49,12 @@ class ShiftedSoftplus(torch.nn.Module): """ def __init__( - self, + self, initial_alpha: float = 1.0, initial_beta: float = 1.0, - trainable: bool = False) -> None: - - """ + trainable: bool = False, + ) -> None: + """ Args: initial_alpha: Initial "scale" alpha of the softplus function. initial_beta: Initial "temperature" beta of the softplus function. @@ -86,4 +86,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.beta != 0, (torch.nn.functional.softplus(self.beta * x) - math.log(2)) / self.beta, 0.5 * x, - ) \ No newline at end of file + ) diff --git a/src/schnetpack/nn/blocks.py b/src/schnetpack/nn/blocks.py index 8b23df4ce..6438bac6f 100644 --- a/src/schnetpack/nn/blocks.py +++ b/src/schnetpack/nn/blocks.py @@ -169,7 +169,7 @@ def __init__( bias: bool = True, zero_init: bool = True, ) -> None: - """ + """ Args: num_features: Dimensions of feature space. activation: activation function @@ -177,14 +177,14 @@ def __init__( super(Residual, self).__init__() # initialize attributes - self.activation1 = activation#(num_features) + self.activation1 = activation # (num_features) self.linear1 = nn.Linear(num_features, num_features, bias=bias) - self.activation2 = activation#(num_features) + self.activation2 = activation # (num_features) self.linear2 = nn.Linear(num_features, num_features, bias=bias) self.reset_parameters(bias, zero_init) def reset_parameters(self, bias: bool = True, zero_init: bool = True) -> None: - """ Initialize parameters to compute an identity mapping. """ + """Initialize parameters to compute an identity mapping.""" nn.init.orthogonal_(self.linear1.weight) if zero_init: nn.init.zeros_(self.linear2.weight) @@ -224,8 +224,9 @@ def __init__( num_residual: int, activation: Union[Callable, nn.Module], bias: bool = True, - zero_init: bool = True) -> None: - """ + zero_init: bool = True, + ) -> None: + """ Args: num_blocks: Number of residual blocks to be stacked in sequence. num_features: Dimensions of feature space. @@ -254,18 +255,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for residual in self.stack: x = residual(x) return x - + class ResidualMLP(nn.Module): """Residual MLP with num_residual residual blocks.""" + def __init__( self, num_features: int, num_residual: int, activation: Union[Callable, nn.Module], bias: bool = True, - zero_init: bool = False): - + zero_init: bool = False, + ): """ Args: num_features: Dimensions of feature space. @@ -291,4 +293,4 @@ def reset_parameters(self, bias: bool = True, zero_init: bool = False) -> None: nn.init.zeros_(self.linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(self.activation(self.residual(x))) \ No newline at end of file + return self.linear(self.activation(self.residual(x))) diff --git a/src/schnetpack/nn/embedding.py b/src/schnetpack/nn/embedding.py index 49f46abd2..752037d86 100644 --- a/src/schnetpack/nn/embedding.py +++ b/src/schnetpack/nn/embedding.py @@ -12,7 +12,7 @@ __all__ = ["NuclearEmbedding", "ElectronicEmbedding"] -''' +""" The usage of the electron configuration is to provide a shorthand descriptor. This descriptor encode information about the groundstate information of an atom, the nuclear charge and the number of electrons in the valence shell. @@ -39,7 +39,7 @@ (Indicated by the same pattern in the electron configuration) -''' +""" # fmt: off # up until Z = 100; vs = valence s, vp = valence p, vd = valence d, vf = valence f. @@ -149,9 +149,9 @@ [100, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 12,0, 2, 0, 0, 12] # Fm ], dtype=np.float32) -# fmt: on -# normalize entries (between 0.0 and 1.0) -# normalization just for numerical reasons +# fmt: on +# normalize entries (between 0.0 and 1.0) +# normalization just for numerical reasons electron_config = electron_config / np.max(electron_config, axis=0) @@ -169,16 +169,12 @@ class NuclearEmbedding(nn.Module): The model will converge to a lower value, but the duration is longer. """ - def __init__( - self, - max_z: int, - num_features: int, - zero_init: bool = True): - """ + def __init__(self, max_z: int, num_features: int, zero_init: bool = True): + """ Args: num_features: Dimensions of feature space. Zmax: Maximum nuclear charge of atoms. The default is 100, so all - elements up to Fermium (Fe) (Z=100) are supported. + elements up to Fermium (Fe) (Z=100) are supported. Can be kept at the default value (has minimal memory impact). zero_init: If True, initialize the embedding with zeros. Otherwise, use uniform initialization. @@ -186,13 +182,19 @@ def __init__( super(NuclearEmbedding, self).__init__() self.num_features = num_features self.register_buffer("electron_config", torch.tensor(electron_config)) - self.register_parameter("element_embedding", nn.Parameter(torch.Tensor(max_z, self.num_features))) - self.register_buffer("embedding", torch.Tensor(max_z, self.num_features), persistent=False) - self.config_linear = nn.Linear(self.electron_config.size(1), self.num_features, bias=False) + self.register_parameter( + "element_embedding", nn.Parameter(torch.Tensor(max_z, self.num_features)) + ) + self.register_buffer( + "embedding", torch.Tensor(max_z, self.num_features), persistent=False + ) + self.config_linear = nn.Linear( + self.electron_config.size(1), self.num_features, bias=False + ) self.reset_parameters(zero_init) def reset_parameters(self, zero_init: bool = True) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" if zero_init: nn.init.zeros_(self.element_embedding) nn.init.zeros_(self.config_linear.weight) @@ -201,7 +203,7 @@ def reset_parameters(self, zero_init: bool = True) -> None: nn.init.orthogonal_(self.config_linear.weight) def train(self, mode: bool = True) -> None: - """ Switch between training and evaluation mode. """ + """Switch between training and evaluation mode.""" super(NuclearEmbedding, self).train(mode=mode) if not self.training: with torch.no_grad(): @@ -228,7 +230,9 @@ def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor: return self.embedding[atomic_numbers] else: # gathering is faster on GPUs return torch.gather( - self.embedding, 0, atomic_numbers.view(-1, 1).expand(-1, self.num_features) + self.embedding, + 0, + atomic_numbers.view(-1, 1).expand(-1, self.num_features), ) @@ -246,7 +250,7 @@ def __init__( property_key: str, num_features: int, is_charged: bool, - num_residual: int = 1, + num_residual: int = 1, activation: Union[Callable, nn.Module] = shifted_softplus, epsilon: float = 1e-8, ): @@ -284,7 +288,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" nn.init.orthogonal_(self.linear_k.weight) nn.init.orthogonal_(self.linear_v.weight) nn.init.orthogonal_(self.linear_q.weight) @@ -310,13 +314,13 @@ def forward( # queries (Batchsize x N_atoms, n_atom_basis) q = self.linear_q(input_embedding) - + # to account for negative and positive charge if self.is_charged: e = F.relu(torch.stack([electronic_feature, -electronic_feature], dim=-1)) # +/- spin is the same => abs else: - e = torch.abs(electronic_feature).unsqueeze(-1) + e = torch.abs(electronic_feature).unsqueeze(-1) enorm = torch.maximum(e, torch.ones_like(e)) # keys (Batchsize x N_atoms, n_atom_basis), the idx_m ensures that the key is the same for all atoms belonging to the same graph @@ -330,12 +334,14 @@ def forward( weights = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # probability distribution of scaled unnormalized attention weights, by applying softmax function - a = nn.functional.softmax(weights, dim=0) # nn.functional.softplus(weights) seems to function to but softmax might be more stable + a = nn.functional.softmax( + weights, dim=0 + ) # nn.functional.softplus(weights) seems to function to but softmax might be more stable # normalization factor for every molecular graph, by adding up attention weights of every atom in the graph anorm = a.new_zeros(num_batch).index_add_(0, idx_m, a) - # make tensor filled with anorm value at the position of the corresponding molecular graph, + # make tensor filled with anorm value at the position of the corresponding molecular graph, # indexing faster on CPU, gather faster on GPU - if a.device.type == "cpu": + if a.device.type == "cpu": anorm = anorm[idx_m] else: anorm = torch.gather(anorm, 0, idx_m) diff --git a/src/schnetpack/nn/so3.py b/src/schnetpack/nn/so3.py index 32d7896b8..9b8ca3e29 100644 --- a/src/schnetpack/nn/so3.py +++ b/src/schnetpack/nn/so3.py @@ -300,7 +300,9 @@ def forward( * self.clebsch_gordan[None, :, None] * xj ) - yij = snn.scatter_add(v, self.idx_out, dim_size=int((self.lmax + 1) ** 2), dim=1) + yij = snn.scatter_add( + v, self.idx_out, dim_size=int((self.lmax + 1) ** 2), dim=1 + ) y = snn.scatter_add(yij, idx_i, dim_size=x.shape[0]) return y diff --git a/src/schnetpack/properties.py b/src/schnetpack/properties.py index 20d2a0a79..d990da1af 100644 --- a/src/schnetpack/properties.py +++ b/src/schnetpack/properties.py @@ -4,6 +4,7 @@ Note: Had to be moved out of Structure class for TorchScript compatibility """ + from typing import Final idx: Final[str] = "_idx" @@ -25,13 +26,13 @@ idx_j_lr: Final[str] = "_idx_j_lr" #: indices of neighboring atoms for long-range lidx_i: Final[str] = "_idx_i_local" #: local indices of center atoms (within system) -lidx_j: Final[ - str -] = "_idx_j_local" #: local indices of neighboring atoms (within system) +lidx_j: Final[str] = ( + "_idx_j_local" #: local indices of neighboring atoms (within system) +) Rij: Final[str] = "_Rij" #: vectors pointing from center atoms to neighboring atoms -Rij_lr: Final[ - str -] = "_Rij_lr" #: vectors pointing from center atoms to neighboring atoms for long range +Rij_lr: Final[str] = ( + "_Rij_lr" #: vectors pointing from center atoms to neighboring atoms for long range +) n_atoms: Final[str] = "_n_atoms" #: number of atoms offsets: Final[str] = "_offsets" #: cell offset vectors offsets_lr: Final[str] = "_offsets_lr" #: cell offset vectors for long range diff --git a/src/schnetpack/representation/painn.py b/src/schnetpack/representation/painn.py index 383167545..c57d413e6 100644 --- a/src/schnetpack/representation/painn.py +++ b/src/schnetpack/representation/painn.py @@ -174,7 +174,7 @@ def __init__( electronic_embeddings = [] electronic_embeddings = nn.ModuleList(electronic_embeddings) self.electronic_embeddings = electronic_embeddings - + # initialize filter layers self.share_filters = shared_filters if shared_filters: diff --git a/src/schnetpack/representation/schnet.py b/src/schnetpack/representation/schnet.py index 7a2f3906c..fd5e88a57 100644 --- a/src/schnetpack/representation/schnet.py +++ b/src/schnetpack/representation/schnet.py @@ -70,7 +70,6 @@ def forward( return x - class SchNet(nn.Module): """SchNet architecture for learning representations of atomistic systems diff --git a/src/schnetpack/representation/so3net.py b/src/schnetpack/representation/so3net.py index dd3dea783..f8905566b 100644 --- a/src/schnetpack/representation/so3net.py +++ b/src/schnetpack/representation/so3net.py @@ -135,7 +135,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]): # compute interaction blocks and update atomic embeddings x = so3.scalar2rsh(x0, int(self.lmax)) for so3conv, mixing1, mixing2, gating, mixing3 in zip( - self.so3convs, self.mixings1, self.mixings2, self.gatings, self.mixings3 + self.so3convs, self.mixings1, self.mixings2, self.gatings, self.mixings3 ): dx = so3conv(x, radial_ij, Yij, cutoff_ij, idx_i, idx_j) ddx = mixing1(dx) diff --git a/src/schnetpack/task.py b/src/schnetpack/task.py index c0db96def..5d92a1d9f 100644 --- a/src/schnetpack/task.py +++ b/src/schnetpack/task.py @@ -202,7 +202,14 @@ def validation_step(self, batch, batch_idx): loss = self.loss_fn(pred, targets) - self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch['_idx'])) + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch["_idx"]), + ) self.log_metrics(pred, targets, "val") return {"val_loss": loss} @@ -225,7 +232,14 @@ def test_step(self, batch, batch_idx): loss = self.loss_fn(pred, targets) - self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch['_idx'])) + self.log( + "test_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch["_idx"]), + ) self.log_metrics(pred, targets, "test") return {"test_loss": loss} diff --git a/src/schnetpack/train/callbacks.py b/src/schnetpack/train/callbacks.py index 9bafc3183..10f7e13eb 100644 --- a/src/schnetpack/train/callbacks.py +++ b/src/schnetpack/train/callbacks.py @@ -24,10 +24,10 @@ class PredictionWriter(BasePredictionWriter): """ def __init__( - self, - output_dir: str, - write_interval: str, - write_idx: bool = False, + self, + output_dir: str, + write_interval: str, + write_idx: bool = False, ): """ Args: diff --git a/src/schnetpack/transform/neighborlist.py b/src/schnetpack/transform/neighborlist.py index a75d7ca62..870b63b8b 100644 --- a/src/schnetpack/transform/neighborlist.py +++ b/src/schnetpack/transform/neighborlist.py @@ -355,15 +355,9 @@ def _update(self, inputs): < 0.25 * self.cutoff_skin**2 ): # reuse previous neighbor list - inputs[properties.idx_i] = ( - previous_inputs[properties.idx_i].clone() - ) - inputs[properties.idx_j] = ( - previous_inputs[properties.idx_j].clone() - ) - inputs[properties.offsets] = ( - previous_inputs[properties.offsets].clone() - ) + inputs[properties.idx_i] = previous_inputs[properties.idx_i].clone() + inputs[properties.idx_j] = previous_inputs[properties.idx_j].clone() + inputs[properties.offsets] = previous_inputs[properties.offsets].clone() return False, inputs # build new neighbor list