Skip to content

Commit

Permalink
fix ruff FBT001 FBT002
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed May 19, 2024
1 parent 6c83227 commit 934b71f
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 11 deletions.
12 changes: 11 additions & 1 deletion chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
structures: list[Structure],
energies: list[float],
forces: list[Sequence[Sequence[float]]],
*,
stresses: list[Sequence[Sequence[float]]] | None = None,
magmoms: list[Sequence[Sequence[float]]] | None = None,
structure_ids: list | None = None,
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
def from_vasp(
cls,
file_root: str,
*,
check_electronic_convergence: bool = True,
save_path: str | None = None,
graph_converter: CrystalGraphConverter | None = None,
Expand Down Expand Up @@ -196,6 +198,7 @@ class CIFData(Dataset):
def __init__(
self,
cif_path: str,
*,
labels: str | dict = "labels.json",
targets: TrainTask = "efsm",
graph_converter: CrystalGraphConverter | None = None,
Expand Down Expand Up @@ -311,6 +314,7 @@ class GraphData(Dataset):
def __init__(
self,
graph_path: str,
*,
labels: str | dict = "labels.json",
targets: TrainTask = "efsm",
exclude: str | list | None = None,
Expand Down Expand Up @@ -429,6 +433,7 @@ def get_train_val_test_loader(
self,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
*,
train_key: list[str] | None = None,
val_key: list[str] | None = None,
test_key: list[str] | None = None,
Expand Down Expand Up @@ -541,6 +546,7 @@ def __init__(
self,
data: str | dict,
graph_converter: CrystalGraphConverter,
*,
targets: TrainTask = "efsm",
energy_key: str = "energy_per_atom",
force_key: str = "force",
Expand Down Expand Up @@ -654,6 +660,7 @@ def get_train_val_test_loader(
self,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
*,
train_key: list[str] | None = None,
val_key: list[str] | None = None,
test_key: list[str] | None = None,
Expand Down Expand Up @@ -777,6 +784,7 @@ def collate_graphs(batch_data: list):

def get_train_val_test_loader(
dataset: Dataset,
*,
batch_size: int = 64,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
Expand Down Expand Up @@ -842,7 +850,9 @@ def get_train_val_test_loader(
return train_loader, val_loader


def get_loader(dataset, batch_size=64, num_workers=0, pin_memory=True):
def get_loader(
dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True
):
"""Get a dataloader from a dataset.
Args:
Expand Down
1 change: 1 addition & 0 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CrystalGraphConverter(nn.Module):

def __init__(
self,
*,
atom_graph_cutoff: float = 6,
bond_graph_cutoff: float = 3,
algorithm: Literal["legacy", "fast"] = "fast",
Expand Down
5 changes: 3 additions & 2 deletions chgnet/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class Fourier(nn.Module):
"""Fourier Expansion for angle features."""

def __init__(self, order: int = 5, learnable: bool = False) -> None:
def __init__(self, *, order: int = 5, learnable: bool = False) -> None:
"""Initialize the Fourier expansion.
Args:
Expand Down Expand Up @@ -47,6 +47,7 @@ class RadialBessel(torch.nn.Module):

def __init__(
self,
*,
num_radial: int = 9,
cutoff: float = 5,
learnable: bool = False,
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(
self.smooth_cutoff = None

def forward(
self, dist: Tensor, return_smooth_factor: bool = False
self, dist: Tensor, *, return_smooth_factor: bool = False
) -> Tensor | tuple[Tensor, Tensor]:
"""Apply Bessel expansion to a feature Tensor.
Expand Down
5 changes: 4 additions & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class CompositionModel(nn.Module):

def __init__(
self,
*,
atom_fea_dim: int = 64,
activation: str = "silu",
is_intensive: bool = True,
Expand Down Expand Up @@ -88,7 +89,9 @@ class AtomRef(nn.Module):
From: https://github.com/materialsvirtuallab/m3gnet/.
"""

def __init__(self, is_intensive: bool = True, max_num_elements: int = 94) -> None:
def __init__(
self, *, is_intensive: bool = True, max_num_elements: int = 94
) -> None:
"""Initialize an AtomRef model."""
super().__init__()
self.is_intensive = is_intensive
Expand Down
6 changes: 5 additions & 1 deletion chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CHGNetCalculator(Calculator):
def __init__(
self,
model: CHGNet | None = None,
*,
use_device: str | None = None,
check_cuda_mem: bool = True,
stress_weight: float | None = 1 / 160.21766208,
Expand Down Expand Up @@ -215,6 +216,7 @@ def n_params(self) -> int:
def relax(
self,
atoms: Structure | Atoms,
*,
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
Expand Down Expand Up @@ -419,6 +421,7 @@ class MolecularDynamics:
def __init__(
self,
atoms: Atoms | Structure,
*,
model: CHGNet | CHGNetCalculator | None = None,
ensemble: str = "nvt",
thermostat: str = "Berendsen_inhomogeneous",
Expand Down Expand Up @@ -729,7 +732,7 @@ def set_atoms(self, atoms: Atoms) -> None:
self.dyn.atoms = atoms
self.dyn.atoms.calc = calculator

def upper_triangular_cell(self, verbose: bool | None = False) -> None:
def upper_triangular_cell(self, *, verbose: bool | None = False) -> None:
"""Transform to upper-triangular cell.
ASE Nose-Hoover implementation only supports upper-triangular cell
while ASE's canonical description is lower-triangular cell.
Expand Down Expand Up @@ -799,6 +802,7 @@ def __init__(
def fit(
self,
atoms: Structure | Atoms,
*,
n_points: int = 11,
fmax: float | None = 0.1,
steps: int | None = 500,
Expand Down
3 changes: 2 additions & 1 deletion chgnet/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class BondEncoder(nn.Module):

def __init__(
self,
*,
atom_graph_cutoff: float = 5,
bond_graph_cutoff: float = 3,
num_radial: int = 9,
Expand Down Expand Up @@ -113,7 +114,7 @@ def forward(
class AngleEncoder(nn.Module):
"""Encode an angle given the two bond vectors using Fourier Expansion."""

def __init__(self, num_angular: int = 9, learnable: bool = True) -> None:
def __init__(self, *, num_angular: int = 9, learnable: bool = True) -> None:
"""Initialize the angle encoder.
Args:
Expand Down
4 changes: 3 additions & 1 deletion chgnet/model/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch import Tensor, nn


def aggregate(data: Tensor, owners: Tensor, average=True, num_owner=None) -> Tensor:
def aggregate(data: Tensor, owners: Tensor, *, average=True, num_owner=None) -> Tensor:
"""Aggregate rows in data by specifying the owners.
Args:
Expand Down Expand Up @@ -45,6 +45,7 @@ class MLP(nn.Module):
def __init__(
self,
input_dim: int,
*,
output_dim: int = 1,
hidden_dim: int | Sequence[int] | None = (64, 64),
dropout: float = 0,
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
self,
input_dim: int,
output_dim: int,
*,
hidden_dim: int | list[int] | None = None,
dropout: float = 0,
activation: str = "silu",
Expand Down
12 changes: 10 additions & 2 deletions chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AtomConv(nn.Module):

def __init__(
self,
*,
atom_fea_dim: int,
bond_fea_dim: int,
hidden_dim: int = 64,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
atom_fea_dim: int,
bond_fea_dim: int,
angle_fea_dim: int,
*,
hidden_dim: int = 64,
dropout: float = 0,
activation: str = "silu",
Expand Down Expand Up @@ -271,6 +273,7 @@ def __init__(
atom_fea_dim: int,
bond_fea_dim: int,
angle_fea_dim: int,
*,
hidden_dim: int = 0,
dropout: float = 0,
activation: str = "silu",
Expand Down Expand Up @@ -363,7 +366,7 @@ def forward(
class GraphPooling(nn.Module):
"""Pooling the sub-graphs in the batched graph."""

def __init__(self, average: bool = False) -> None:
def __init__(self, *, average: bool = False) -> None:
"""Args:
average (bool): whether to average the features.
"""
Expand Down Expand Up @@ -392,7 +395,12 @@ class GraphAttentionReadOut(nn.Module):
"""

def __init__(
self, atom_fea_dim: int, num_head: int = 3, hidden_dim: int = 32, average=False
self,
atom_fea_dim: int,
num_head: int = 3,
hidden_dim: int = 32,
*,
average=False,
) -> None:
"""Initialize the layer.
Expand Down
11 changes: 9 additions & 2 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CHGNet(nn.Module):

def __init__(
self,
*,
atom_fea_dim: int = 64,
bond_fea_dim: int = 64,
angle_fea_dim: int = 64,
Expand Down Expand Up @@ -327,6 +328,7 @@ def n_params(self) -> int:
def forward(
self,
graphs: Sequence[CrystalGraph],
*,
task: PredTask = "e",
return_site_energies: bool = False,
return_atom_feas: bool = False,
Expand Down Expand Up @@ -381,7 +383,8 @@ def forward(

def _compute(
self,
g,
g: BatchedGraph,
*,
compute_force: bool = False,
compute_stress: bool = False,
compute_magmom: bool = False,
Expand Down Expand Up @@ -530,6 +533,7 @@ def _compute(
def predict_structure(
self,
structure: Structure | Sequence[Structure],
*,
task: PredTask = "efsm",
return_site_energies: bool = False,
return_atom_feas: bool = False,
Expand Down Expand Up @@ -578,6 +582,7 @@ def predict_structure(
def predict_graph(
self,
graph: CrystalGraph | Sequence[CrystalGraph],
*,
task: PredTask = "efsm",
return_site_energies: bool = False,
return_atom_feas: bool = False,
Expand Down Expand Up @@ -671,7 +676,8 @@ def from_file(cls, path: str, **kwargs) -> CHGNet:
@classmethod
def load(
cls,
model_name="0.3.0",
*,
model_name: str = "0.3.0",
use_device: str | None = None,
check_cuda_mem: bool = True,
verbose: bool = True,
Expand Down Expand Up @@ -769,6 +775,7 @@ def from_graphs(
graphs: Sequence[CrystalGraph],
bond_basis_expansion: nn.Module,
angle_basis_expansion: nn.Module,
*,
compute_stress: bool = False,
) -> BatchedGraph:
"""Featurize and assemble a list of graphs.
Expand Down
4 changes: 4 additions & 0 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Trainer:
def __init__(
self,
model: CHGNet | None = None,
*,
targets: TrainTask = "ef",
energy_loss_ratio: float = 1,
force_loss_ratio: float = 1,
Expand Down Expand Up @@ -199,6 +200,7 @@ def train(
train_loader: DataLoader,
val_loader: DataLoader,
test_loader: DataLoader | None = None,
*,
save_dir: str | None = None,
save_test_result: bool = False,
train_composition_model: bool = False,
Expand Down Expand Up @@ -353,6 +355,7 @@ def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
def _validate(
self,
val_loader: DataLoader,
*,
is_test: bool = False,
test_result_save_path: str | None = None,
) -> dict:
Expand Down Expand Up @@ -572,6 +575,7 @@ class CombinedLoss(nn.Module):

def __init__(
self,
*,
target_str: str = "ef",
criterion: str = "MSE",
is_intensive: bool = True,
Expand Down
1 change: 1 addition & 0 deletions chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

def determine_device(
use_device: str | None = None,
*,
check_cuda_mem: bool = True,
):
"""Determine the device to use for torch model.
Expand Down
1 change: 1 addition & 0 deletions chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

def parse_vasp_dir(
base_dir: str,
*,
check_electronic_convergence: bool = True,
save_path: str | None = None,
) -> dict[str, list]:
Expand Down

0 comments on commit 934b71f

Please sign in to comment.