lint: unsafe fixes
CompRhys committed Aug 20, 2024
1 parent: 60f207d; commit: a02eacc
Showing 35 changed files with 349 additions and 366 deletions.
mace/calculators/foundations_models.py: 18 changes (9 additions, 9 deletions)
@@ -2,21 +2,23 @@

import os
import urllib.request
from pathlib import Path
from typing import Union
from typing import TYPE_CHECKING

import torch
from ase import units
from ase.calculators.mixing import SumCalculator

from .mace import MACECalculator

if TYPE_CHECKING:
from pathlib import Path

module_dir = os.path.dirname(__file__)
local_model_path = os.path.join(module_dir, "foundations_models/2023-12-03-mace-mp.model")


def mace_mp(
model: Union[str, Path] = None,
model: str | Path | None = None,
device: str = "",
default_dtype: str = "float32",
dispersion: bool = False,
@@ -111,12 +113,11 @@ def mace_mp(
cutoff=dispersion_cutoff,
**kwargs,
)
calc = mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc])
return calc
return mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc])


def mace_off(
model: Union[str, Path] = None,
model: str | Path | None = None,
device: str = "",
default_dtype: str = "float64",
return_raw_model: bool = False,
@@ -179,13 +180,12 @@ def mace_off(
print(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)
mace_calc = MACECalculator(model_paths=model, device=device, default_dtype=default_dtype, **kwargs)
return mace_calc
return MACECalculator(model_paths=model, device=device, default_dtype=default_dtype, **kwargs)


def mace_anicc(
device: str = "cuda",
model_path: str = None,
model_path: str | None = None,
) -> MACECalculator:
"""
Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O).
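Note (not part of the diff): a minimal sketch of the typing pattern this file now uses. With from __future__ import annotations, unions written with the | operator in annotations (such as str | Path | None) are not evaluated at runtime, and imports needed only for typing can sit under TYPE_CHECKING. The function name, default, and body below are illustrative assumptions, not code from the repository.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static type checkers; never evaluated at runtime.
    from pathlib import Path


def resolve_model(model: str | Path | None = None) -> str:
    # Fall back to a default identifier when no model is given.
    return "default-model" if model is None else str(model)


print(resolve_model())            # default-model
print(resolve_model("my.model"))  # my.model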
mace/calculators/lammps_mace.py: 27 changes (8 additions, 19 deletions)
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Dict, List, Optional

import torch
from e3nn.util.jit import compile_mode

@@ -21,10 +19,10 @@ def __init__(self, model):

def forward(
self,
data: Dict[str, torch.Tensor],
data: dict[str, torch.Tensor],
local_or_ghost: torch.Tensor,
compute_virials: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
) -> dict[str, torch.Tensor | None]:
num_graphs = data["ptr"].numel() - 1
compute_displacement = False
if compute_virials:
@@ -47,13 +45,13 @@ def forward(
}
positions = data["positions"]
displacement = out["displacement"]
forces: Optional[torch.Tensor] = torch.zeros_like(positions)
virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"])
forces: torch.Tensor | None = torch.zeros_like(positions)
virials: torch.Tensor | None = torch.zeros_like(data["cell"])
# accumulate energies of local atoms
node_energy_local = node_energy * local_or_ghost
total_energy_local = scatter_sum(src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs)
# compute partial forces and (possibly) partial virials
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(total_energy_local)]
grad_outputs: list[torch.Tensor | None] = [torch.ones_like(total_energy_local)]
if compute_virials and displacement is not None:
forces, virials = torch.autograd.grad(
outputs=[total_energy_local],
@@ -63,14 +61,8 @@ def forward(
create_graph=False,
allow_unused=True,
)
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
if virials is not None:
virials = -1 * virials
else:
virials = torch.zeros_like(displacement)
forces = -1 * forces if forces is not None else torch.zeros_like(positions)
virials = -1 * virials if virials is not None else torch.zeros_like(displacement)
else:
forces = torch.autograd.grad(
outputs=[total_energy_local],
@@ -80,10 +72,7 @@ def forward(
create_graph=False,
allow_unused=True,
)[0]
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
forces = -1 * forces if forces is not None else torch.zeros_like(positions)
return {
"total_energy_local": total_energy_local,
"node_energy": node_energy,
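Note (not part of the diff): the hunks above collapse if/else None-guards into conditional expressions. A small sketch of the same pattern, using stand-in tensors rather than the module's real data:

import torch

positions = torch.zeros(4, 3)
grad = None  # torch.autograd.grad(..., allow_unused=True) can return None

# Before the fix:
#     if grad is not None:
#         forces = -1 * grad
#     else:
#         forces = torch.zeros_like(positions)
# After the fix, the same guard as one expression:
forces = -1 * grad if grad is not None else torch.zeros_like(positions)
print(forces.shape)  # torch.Size([4, 3])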
mace/calculators/mace.py: 6 changes (2 additions, 4 deletions)
@@ -7,7 +7,6 @@

from glob import glob
from pathlib import Path
from typing import Union

import numpy as np
import torch
@@ -49,7 +48,7 @@ class MACECalculator(Calculator):

def __init__(
self,
model_paths: Union[list, str],
model_paths: list | str,
device: str,
energy_units_to_eV: float = 1.0,
length_units_to_A: float = 1.0,
@@ -188,8 +187,7 @@ def _atoms_to_batch(self, atoms):
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(self.device)
return batch
return next(iter(data_loader)).to(self.device)

def _clone_batch(self, batch):
batch_clone = batch.clone()
mace/cli/active_learning_md.py: 10 changes (3 additions, 7 deletions)
@@ -68,15 +68,11 @@ def printenergy(dyn, start_time=None): # store a reference to atoms in the defi
a = dyn.atoms
epot = a.get_potential_energy() / len(a)
ekin = a.get_kinetic_energy() / len(a)
if start_time is None:
elapsed_time = 0
else:
elapsed_time = time.time() - start_time
elapsed_time = 0 if start_time is None else time.time() - start_time
forces_var = np.var(a.calc.results["forces_comm"], axis=0)
print(
"%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209
"Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A"
% (
"{:.1f}s: Energy per atom: Epot = {:.3f}eV Ekin = {:.3f}eV (T={:3.0f}K) " # pylint: disable=C0209
"Etot = {:.3f}eV t={:.1f}fs Eerr = {:.3f}eV Ferr = {:.3f}eV/A".format(
elapsed_time,
epot,
ekin,
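Note (not part of the diff): the printenergy hunk swaps %-interpolation for str.format. A toy comparison with made-up values shows the two styles render identically:

elapsed_time, epot, temp = 1.2, -3.456, 300.0

old_style = "%.1fs: Epot = %.3feV (T=%3.0fK)" % (elapsed_time, epot, temp)
new_style = "{:.1f}s: Epot = {:.3f}eV (T={:3.0f}K)".format(elapsed_time, epot, temp)

assert old_style == new_style
print(new_style)  # 1.2s: Epot = -3.456eV (T=300K)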
mace/cli/plot_train.py: 5 changes (2 additions, 3 deletions)
@@ -6,7 +6,6 @@
import json
import os
import re
from typing import List

import matplotlib.pyplot as plt
import numpy as np
@@ -48,7 +47,7 @@ def parse_path(path: str) -> RunInfo:
return RunInfo(name=match.group("name"), seed=int(match.group("seed")))


def parse_training_results(path: str) -> List[dict]:
def parse_training_results(path: str) -> list[dict]:
run_info = parse_path(path)
results = []
with open(path, encoding="utf-8") as f:
@@ -155,7 +154,7 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None:
plt.close(fig)


def get_paths(path: str) -> List[str]:
def get_paths(path: str) -> list[str]:
if os.path.isfile(path):
return [path]
paths = glob.glob(os.path.join(path, "*_train.txt"))
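Note (not part of the diff): typing.List[...] becomes the builtin list[...] in this file. A minimal, self-contained sketch of the rewrite; the helper below is hypothetical, not the script's real get_paths:

from __future__ import annotations


# Before: from typing import List;  def get_txt_paths(...) -> List[str]
def get_txt_paths(names: list[str]) -> list[str]:
    # Keep only entries that look like training logs.
    return [name for name in names if name.endswith("_train.txt")]


print(get_txt_paths(["a_train.txt", "notes.md"]))  # ['a_train.txt']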
mace/cli/preprocess_data.py: 16 changes (9 additions, 7 deletions)
@@ -2,7 +2,6 @@
# new hdf5 file that is ready for training with on-the-fly dataloading
from __future__ import annotations

import argparse
import ast
import json
import logging
@@ -11,7 +10,7 @@
import random
from functools import partial
from glob import glob
from typing import List, Tuple
from typing import TYPE_CHECKING

import h5py
import numpy as np
@@ -22,14 +21,18 @@
from mace.modules import compute_statistics
from mace.tools import torch_geometric
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz
from mace.tools.utils import AtomicNumberTable

if TYPE_CHECKING:
import argparse

from mace.tools.utils import AtomicNumberTable


def compute_stats_target(
file: str,
z_table: AtomicNumberTable,
r_max: float,
atomic_energies: Tuple,
atomic_energies: tuple,
batch_size: int,
):
train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max)
@@ -41,11 +44,10 @@ def compute_stats_target(
)

avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies)
output = [avg_num_neighbors, mean, std]
return output
return [avg_num_neighbors, mean, std]


def pool_compute_stats(inputs: List):
def pool_compute_stats(inputs: list):
path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs

with mp.Pool(processes=num_process) as pool:
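Note (not part of the diff): compute_stats_target now returns its list directly instead of binding it to a temporary first. The same pattern in isolation, with an illustrative function and values:

def summarize(values):
    mean = sum(values) / len(values)
    spread = max(values) - min(values)
    # Before: output = [mean, spread]; return output
    return [mean, spread]


print(summarize([1.0, 2.0, 3.0]))  # [2.0, 2.0]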
mace/cli/run_train.py: 18 changes (9 additions, 9 deletions)
@@ -5,15 +5,14 @@
###########################################################################################
from __future__ import annotations

import argparse
import ast
import glob
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING

import numpy as np
import torch.distributed
@@ -43,6 +42,9 @@
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.utils import AtomicNumberTable

if TYPE_CHECKING:
import argparse


def main() -> None:
"""
@@ -299,7 +301,8 @@ def run(args: argparse.Namespace) -> None:
dipole_weight=args.dipole_weight,
)
elif args.loss == "energy_forces_dipole":
assert dipole_only is False and compute_dipole is True
assert dipole_only is False
assert compute_dipole is True
loss_fn = modules.WeightedEnergyForcesDipoleLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
@@ -533,7 +536,7 @@ def run(args: argparse.Namespace) -> None:

lr_scheduler = LRScheduler(optimizer, args)

swa: Optional[tools.SWAContainer] = None
swa: tools.SWAContainer | None = None
swas = [False]
if args.swa:
assert dipole_only is False, "Stage Two for dipole fitting not implemented"
@@ -613,7 +616,7 @@ def run(args: argparse.Namespace) -> None:
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

ema: Optional[ExponentialMovingAverage] = None
ema: ExponentialMovingAverage | None = None
if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
else:
@@ -642,10 +645,7 @@ def run(args: argparse.Namespace) -> None:
)
wandb.run.summary["params"] = args_dict_json

if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None
distributed_model = DDP(model, device_ids=[local_rank]) if args.distributed else None

tools.train(
model=model,
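Note (not part of the diff): two of the fixes above in isolation. Splitting a compound assert means a failure points at the exact condition, and the if/else around DDP wrapping becomes a conditional expression. The flags and placeholder model below are assumptions for the sketch, not the script's real objects:

dipole_only = False
compute_dipole = True

# Before: assert dipole_only is False and compute_dipole is True
assert dipole_only is False
assert compute_dipole is True

distributed = False
model = object()  # stand-in for the real torch.nn.Module

# Before:
#     if distributed:
#         distributed_model = DDP(model, device_ids=[local_rank])
#     else:
#         distributed_model = None
distributed_model = model if distributed else None  # DDP wrapping omitted in this sketch
print(distributed_model)  # None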
mace/data/atomic_data.py: 33 changes (18 additions, 15 deletions)
@@ -5,14 +5,16 @@
###########################################################################################
from __future__ import annotations

from typing import Optional, Sequence
from typing import TYPE_CHECKING, Sequence

import torch.utils.data

from mace.tools import AtomicNumberTable, atomic_numbers_to_indices, to_one_hot, torch_geometric, voigt_to_matrix

from .neighborhood import get_neighborhood
from .utils import Configuration

if TYPE_CHECKING:
from .utils import Configuration


class AtomicData(torch_geometric.data.Data):
@@ -45,23 +47,24 @@ def __init__(
positions: torch.Tensor, # [n_nodes, 3]
shifts: torch.Tensor, # [n_edges, 3],
unit_shifts: torch.Tensor, # [n_edges, 3]
cell: Optional[torch.Tensor], # [3,3]
weight: Optional[torch.Tensor], # [,]
energy_weight: Optional[torch.Tensor], # [,]
forces_weight: Optional[torch.Tensor], # [,]
stress_weight: Optional[torch.Tensor], # [,]
virials_weight: Optional[torch.Tensor], # [,]
forces: Optional[torch.Tensor], # [n_nodes, 3]
energy: Optional[torch.Tensor], # [, ]
stress: Optional[torch.Tensor], # [1,3,3]
virials: Optional[torch.Tensor], # [1,3,3]
dipole: Optional[torch.Tensor], # [, 3]
charges: Optional[torch.Tensor], # [n_nodes, ]
cell: torch.Tensor | None, # [3,3]
weight: torch.Tensor | None, # [,]
energy_weight: torch.Tensor | None, # [,]
forces_weight: torch.Tensor | None, # [,]
stress_weight: torch.Tensor | None, # [,]
virials_weight: torch.Tensor | None, # [,]
forces: torch.Tensor | None, # [n_nodes, 3]
energy: torch.Tensor | None, # [, ]
stress: torch.Tensor | None, # [1,3,3]
virials: torch.Tensor | None, # [1,3,3]
dipole: torch.Tensor | None, # [, 3]
charges: torch.Tensor | None, # [n_nodes, ]
):
# Check shapes
num_nodes = node_attrs.shape[0]

assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
assert edge_index.shape[0] == 2
assert len(edge_index.shape) == 2
assert positions.shape == (num_nodes, 3)
assert shifts.shape[1] == 3
assert unit_shifts.shape[1] == 3
(Diffs for the remaining 27 changed files are not shown.)
