From 934b71f0249815be8fe7de3daa4466a3726faf26 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 19 May 2024 17:05:48 -0400 Subject: [PATCH] fix ruff FBT001 FBT002 --- chgnet/data/dataset.py | 12 +++++++++++- chgnet/graph/converter.py | 1 + chgnet/model/basis.py | 5 +++-- chgnet/model/composition_model.py | 5 ++++- chgnet/model/dynamics.py | 6 +++++- chgnet/model/encoders.py | 3 ++- chgnet/model/functions.py | 4 +++- chgnet/model/layers.py | 12 ++++++++++-- chgnet/model/model.py | 11 +++++++++-- chgnet/trainer/trainer.py | 4 ++++ chgnet/utils/common_utils.py | 1 + chgnet/utils/vasp_utils.py | 1 + 12 files changed, 54 insertions(+), 11 deletions(-) diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index c30b3a3e..5655bdfd 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -33,6 +33,7 @@ def __init__( structures: list[Structure], energies: list[float], forces: list[Sequence[Sequence[float]]], + *, stresses: list[Sequence[Sequence[float]]] | None = None, magmoms: list[Sequence[Sequence[float]]] | None = None, structure_ids: list | None = None, @@ -91,6 +92,7 @@ def __init__( def from_vasp( cls, file_root: str, + *, check_electronic_convergence: bool = True, save_path: str | None = None, graph_converter: CrystalGraphConverter | None = None, @@ -196,6 +198,7 @@ class CIFData(Dataset): def __init__( self, cif_path: str, + *, labels: str | dict = "labels.json", targets: TrainTask = "efsm", graph_converter: CrystalGraphConverter | None = None, @@ -311,6 +314,7 @@ class GraphData(Dataset): def __init__( self, graph_path: str, + *, labels: str | dict = "labels.json", targets: TrainTask = "efsm", exclude: str | list | None = None, @@ -429,6 +433,7 @@ def get_train_val_test_loader( self, train_ratio: float = 0.8, val_ratio: float = 0.1, + *, train_key: list[str] | None = None, val_key: list[str] | None = None, test_key: list[str] | None = None, @@ -541,6 +546,7 @@ def __init__( self, data: str | dict, graph_converter: CrystalGraphConverter, + *, targets: TrainTask = "efsm", energy_key: str = "energy_per_atom", force_key: str = "force", @@ -654,6 +660,7 @@ def get_train_val_test_loader( self, train_ratio: float = 0.8, val_ratio: float = 0.1, + *, train_key: list[str] | None = None, val_key: list[str] | None = None, test_key: list[str] | None = None, @@ -777,6 +784,7 @@ def collate_graphs(batch_data: list): def get_train_val_test_loader( dataset: Dataset, + *, batch_size: int = 64, train_ratio: float = 0.8, val_ratio: float = 0.1, @@ -842,7 +850,9 @@ def get_train_val_test_loader( return train_loader, val_loader -def get_loader(dataset, batch_size=64, num_workers=0, pin_memory=True): +def get_loader( + dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True +): """Get a dataloader from a dataset. Args: diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 08df92d1..12a07441 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -33,6 +33,7 @@ class CrystalGraphConverter(nn.Module): def __init__( self, + *, atom_graph_cutoff: float = 6, bond_graph_cutoff: float = 3, algorithm: Literal["legacy", "fast"] = "fast", diff --git a/chgnet/model/basis.py b/chgnet/model/basis.py index 2468d2d5..10c08a11 100644 --- a/chgnet/model/basis.py +++ b/chgnet/model/basis.py @@ -8,7 +8,7 @@ class Fourier(nn.Module): """Fourier Expansion for angle features.""" - def __init__(self, order: int = 5, learnable: bool = False) -> None: + def __init__(self, *, order: int = 5, learnable: bool = False) -> None: """Initialize the Fourier expansion. Args: @@ -47,6 +47,7 @@ class RadialBessel(torch.nn.Module): def __init__( self, + *, num_radial: int = 9, cutoff: float = 5, learnable: bool = False, @@ -90,7 +91,7 @@ def __init__( self.smooth_cutoff = None def forward( - self, dist: Tensor, return_smooth_factor: bool = False + self, dist: Tensor, *, return_smooth_factor: bool = False ) -> Tensor | tuple[Tensor, Tensor]: """Apply Bessel expansion to a feature Tensor. diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 4b3882af..65edfff6 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -24,6 +24,7 @@ class CompositionModel(nn.Module): def __init__( self, + *, atom_fea_dim: int = 64, activation: str = "silu", is_intensive: bool = True, @@ -88,7 +89,9 @@ class AtomRef(nn.Module): From: https://github.com/materialsvirtuallab/m3gnet/. """ - def __init__(self, is_intensive: bool = True, max_num_elements: int = 94) -> None: + def __init__( + self, *, is_intensive: bool = True, max_num_elements: int = 94 + ) -> None: """Initialize an AtomRef model.""" super().__init__() self.is_intensive = is_intensive diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 7ad4cc29..38b34cde 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -55,6 +55,7 @@ class CHGNetCalculator(Calculator): def __init__( self, model: CHGNet | None = None, + *, use_device: str | None = None, check_cuda_mem: bool = True, stress_weight: float | None = 1 / 160.21766208, @@ -215,6 +216,7 @@ def n_params(self) -> int: def relax( self, atoms: Structure | Atoms, + *, fmax: float | None = 0.1, steps: int | None = 500, relax_cell: bool | None = True, @@ -419,6 +421,7 @@ class MolecularDynamics: def __init__( self, atoms: Atoms | Structure, + *, model: CHGNet | CHGNetCalculator | None = None, ensemble: str = "nvt", thermostat: str = "Berendsen_inhomogeneous", @@ -729,7 +732,7 @@ def set_atoms(self, atoms: Atoms) -> None: self.dyn.atoms = atoms self.dyn.atoms.calc = calculator - def upper_triangular_cell(self, verbose: bool | None = False) -> None: + def upper_triangular_cell(self, *, verbose: bool | None = False) -> None: """Transform to upper-triangular cell. ASE Nose-Hoover implementation only supports upper-triangular cell while ASE's canonical description is lower-triangular cell. @@ -799,6 +802,7 @@ def __init__( def fit( self, atoms: Structure | Atoms, + *, n_points: int = 11, fmax: float | None = 0.1, steps: int | None = 500, diff --git a/chgnet/model/encoders.py b/chgnet/model/encoders.py index d2eb4059..c68acefb 100644 --- a/chgnet/model/encoders.py +++ b/chgnet/model/encoders.py @@ -39,6 +39,7 @@ class BondEncoder(nn.Module): def __init__( self, + *, atom_graph_cutoff: float = 5, bond_graph_cutoff: float = 3, num_radial: int = 9, @@ -113,7 +114,7 @@ def forward( class AngleEncoder(nn.Module): """Encode an angle given the two bond vectors using Fourier Expansion.""" - def __init__(self, num_angular: int = 9, learnable: bool = True) -> None: + def __init__(self, *, num_angular: int = 9, learnable: bool = True) -> None: """Initialize the angle encoder. Args: diff --git a/chgnet/model/functions.py b/chgnet/model/functions.py index 1e9d84c4..764b4ccc 100644 --- a/chgnet/model/functions.py +++ b/chgnet/model/functions.py @@ -6,7 +6,7 @@ from torch import Tensor, nn -def aggregate(data: Tensor, owners: Tensor, average=True, num_owner=None) -> Tensor: +def aggregate(data: Tensor, owners: Tensor, *, average=True, num_owner=None) -> Tensor: """Aggregate rows in data by specifying the owners. Args: @@ -45,6 +45,7 @@ class MLP(nn.Module): def __init__( self, input_dim: int, + *, output_dim: int = 1, hidden_dim: int | Sequence[int] | None = (64, 64), dropout: float = 0, @@ -114,6 +115,7 @@ def __init__( self, input_dim: int, output_dim: int, + *, hidden_dim: int | list[int] | None = None, dropout: float = 0, activation: str = "silu", diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py index 725bba5e..ebb055bb 100644 --- a/chgnet/model/layers.py +++ b/chgnet/model/layers.py @@ -17,6 +17,7 @@ class AtomConv(nn.Module): def __init__( self, + *, atom_fea_dim: int, bond_fea_dim: int, hidden_dim: int = 64, @@ -144,6 +145,7 @@ def __init__( atom_fea_dim: int, bond_fea_dim: int, angle_fea_dim: int, + *, hidden_dim: int = 64, dropout: float = 0, activation: str = "silu", @@ -271,6 +273,7 @@ def __init__( atom_fea_dim: int, bond_fea_dim: int, angle_fea_dim: int, + *, hidden_dim: int = 0, dropout: float = 0, activation: str = "silu", @@ -363,7 +366,7 @@ def forward( class GraphPooling(nn.Module): """Pooling the sub-graphs in the batched graph.""" - def __init__(self, average: bool = False) -> None: + def __init__(self, *, average: bool = False) -> None: """Args: average (bool): whether to average the features. """ @@ -392,7 +395,12 @@ class GraphAttentionReadOut(nn.Module): """ def __init__( - self, atom_fea_dim: int, num_head: int = 3, hidden_dim: int = 32, average=False + self, + atom_fea_dim: int, + num_head: int = 3, + hidden_dim: int = 32, + *, + average=False, ) -> None: """Initialize the layer. diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 11f9e1e0..2207cd53 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -37,6 +37,7 @@ class CHGNet(nn.Module): def __init__( self, + *, atom_fea_dim: int = 64, bond_fea_dim: int = 64, angle_fea_dim: int = 64, @@ -327,6 +328,7 @@ def n_params(self) -> int: def forward( self, graphs: Sequence[CrystalGraph], + *, task: PredTask = "e", return_site_energies: bool = False, return_atom_feas: bool = False, @@ -381,7 +383,8 @@ def forward( def _compute( self, - g, + g: BatchedGraph, + *, compute_force: bool = False, compute_stress: bool = False, compute_magmom: bool = False, @@ -530,6 +533,7 @@ def _compute( def predict_structure( self, structure: Structure | Sequence[Structure], + *, task: PredTask = "efsm", return_site_energies: bool = False, return_atom_feas: bool = False, @@ -578,6 +582,7 @@ def predict_structure( def predict_graph( self, graph: CrystalGraph | Sequence[CrystalGraph], + *, task: PredTask = "efsm", return_site_energies: bool = False, return_atom_feas: bool = False, @@ -671,7 +676,8 @@ def from_file(cls, path: str, **kwargs) -> CHGNet: @classmethod def load( cls, - model_name="0.3.0", + *, + model_name: str = "0.3.0", use_device: str | None = None, check_cuda_mem: bool = True, verbose: bool = True, @@ -769,6 +775,7 @@ def from_graphs( graphs: Sequence[CrystalGraph], bond_basis_expansion: nn.Module, angle_basis_expansion: nn.Module, + *, compute_stress: bool = False, ) -> BatchedGraph: """Featurize and assemble a list of graphs. diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index b2730c8c..64a223f7 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -33,6 +33,7 @@ class Trainer: def __init__( self, model: CHGNet | None = None, + *, targets: TrainTask = "ef", energy_loss_ratio: float = 1, force_loss_ratio: float = 1, @@ -199,6 +200,7 @@ def train( train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader | None = None, + *, save_dir: str | None = None, save_test_result: bool = False, train_composition_model: bool = False, @@ -353,6 +355,7 @@ def _train(self, train_loader: DataLoader, current_epoch: int) -> dict: def _validate( self, val_loader: DataLoader, + *, is_test: bool = False, test_result_save_path: str | None = None, ) -> dict: @@ -572,6 +575,7 @@ class CombinedLoss(nn.Module): def __init__( self, + *, target_str: str = "ef", criterion: str = "MSE", is_intensive: bool = True, diff --git a/chgnet/utils/common_utils.py b/chgnet/utils/common_utils.py index 8b37e871..e6cfc7ec 100644 --- a/chgnet/utils/common_utils.py +++ b/chgnet/utils/common_utils.py @@ -10,6 +10,7 @@ def determine_device( use_device: str | None = None, + *, check_cuda_mem: bool = True, ): """Determine the device to use for torch model. diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py index b748d385..e0f6b7b6 100644 --- a/chgnet/utils/vasp_utils.py +++ b/chgnet/utils/vasp_utils.py @@ -16,6 +16,7 @@ def parse_vasp_dir( base_dir: str, + *, check_electronic_convergence: bool = True, save_path: str | None = None, ) -> dict[str, list]: