From 7b95af3accc389f87911a8ef985527e3a6702abc Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 19 May 2024 17:15:45 -0400 Subject: [PATCH] ruff select = ["ALL"] and fix legacy errors --- chgnet/data/dataset.py | 14 ++++----- chgnet/graph/crystalgraph.py | 2 +- chgnet/graph/graph.py | 52 +++++++++++++++++-------------- chgnet/model/basis.py | 10 +++--- chgnet/model/composition_model.py | 15 ++++----- chgnet/model/dynamics.py | 8 +++-- chgnet/model/functions.py | 18 +++++------ chgnet/model/layers.py | 6 ++-- chgnet/model/model.py | 4 +-- chgnet/trainer/trainer.py | 2 +- chgnet/utils/common_utils.py | 2 +- examples/make_graphs.py | 2 +- pyproject.toml | 38 ++++------------------ 13 files changed, 78 insertions(+), 95 deletions(-) diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index 5655bdfd..01296479 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -64,7 +64,7 @@ def __init__( """ for idx, struct in enumerate(structures): if not isinstance(struct, Structure): - raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}") + raise TypeError(f"{idx} is not a pymatgen Structure object: {struct}") for name in "energies forces stresses magmoms structure_ids".split(): labels = locals()[name] if labels is not None and len(labels) != len(structures): @@ -97,7 +97,7 @@ def from_vasp( save_path: str | None = None, graph_converter: CrystalGraphConverter | None = None, shuffle: bool = True, - ): + ) -> StructureData: """Parse VASP output files into structures and labels and feed into the dataset. Args: @@ -586,7 +586,7 @@ def __init__( elif isinstance(data, dict): self.data = data else: - raise ValueError(f"data must be JSON path or dictionary, got {type(data)}") + raise TypeError(f"data must be JSON path or dictionary, got {type(data)}") self.keys = [ (mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct @@ -608,7 +608,7 @@ def __len__(self) -> int: return len(self.keys) @functools.cache # Cache loaded structures - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]: """Get one item in the dataset. Returns: @@ -754,7 +754,7 @@ def get_train_val_test_loader( return train_loader, val_loader, test_loader -def collate_graphs(batch_data: list): +def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tensor]]: """Collate of list of (graph, target) into batch data. Args: @@ -791,7 +791,7 @@ def get_train_val_test_loader( return_test: bool = True, num_workers: int = 0, pin_memory: bool = True, -): +) -> tuple[DataLoader, DataLoader, DataLoader]: """Randomly partition a dataset into train, val, test loaders. Args: @@ -852,7 +852,7 @@ def get_train_val_test_loader( def get_loader( dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True -): +) -> DataLoader: """Get a dataloader from a dataset. Args: diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py index cd993e1e..566df036 100644 --- a/chgnet/graph/crystalgraph.py +++ b/chgnet/graph/crystalgraph.py @@ -183,7 +183,7 @@ def __repr__(self) -> str: ) @property - def num_isolated_atoms(self): + def num_isolated_atoms(self) -> int: """Number of isolated atoms given the atom graph cutoff Isolated atoms are disconnected nodes in the atom graph that will not get updated in CHGNet. 
diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index 25383348..d14291e1 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -66,7 +66,7 @@ class UndirectedEdge(Edge): __hash__ = Edge.__hash__ - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Check if two undirected edges are equal.""" return set(self.nodes) == set(other.nodes) and self.info == other.info @@ -178,16 +178,16 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None: ): # There is an undirected edge with similar length and only one of # the directed edges associated has been added - added_DE = self.directed_edges_list[ + added_dir_edge = self.directed_edges_list[ undirected_edge.info["directed_edge_index"][0] ] # See if the DE that's associated to this UDE # is the reverse of our DE - if added_DE == this_directed_edge: + if added_dir_edge == this_directed_edge: # Add UDE index to this DE this_directed_edge.info["undirected_edge_index"] = ( - added_DE.info["undirected_edge_index"] + added_dir_edge.info["undirected_edge_index"] ) # At the center node, draw edge with this DE @@ -217,7 +217,7 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None: self.nodes[center_index].add_neighbor(neighbor_index, this_directed_edge) self.directed_edges_list.append(this_directed_edge) - def adjacency_list(self): + def adjacency_list(self) -> tuple[list[list[int]], list[int]]: """Get the adjacency list Return: graph: the adjacency list @@ -240,7 +240,7 @@ def adjacency_list(self): ] return graph, directed2undirected - def line_graph_adjacency_list(self, cutoff): + def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]: """Get the line graph adjacency list. Args: @@ -264,11 +264,12 @@ def line_graph_adjacency_list(self, cutoff): a list of length = num_undirected_edge that maps the undirected edge index to one of its directed edges indices """ - assert len(self.directed_edges_list) == 2 * len(self.undirected_edges_list), ( - f"Error: number of directed edges={len(self.directed_edges_list)} != 2 * " - f"number of undirected edges={len(self.directed_edges_list)}!" - f"This indicates directed edges are not complete" - ) + if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list): + raise ValueError( + f"Error: number of directed edges={len(self.directed_edges_list)} != 2 " + f"* number of undirected edges={len(self.directed_edges_list)}!" + f"This indicates directed edges are not complete" + ) line_graph = [] undirected2directed = [] @@ -285,14 +286,15 @@ def line_graph_adjacency_list(self, cutoff): # if encountered exception, # it means after Atom_Graph creation, the UDE has only 1 DE associated # This exception is not encountered from the develop team's experience - assert len(u_edge.info["directed_edge_index"]) == 2, ( - "Did not find 2 Directed_edges !!!" - f"undirected edge {u_edge} has:" - f"edge.info['directed_edge_index'] = " - f"{u_edge.info['directed_edge_index']}" - f"len directed_edges_list = {len(self.directed_edges_list)}" - f"len undirected_edges_list = {len(self.undirected_edges_list)}" - ) + if len(u_edge.info["directed_edge_index"]) != 2: + raise ValueError( + "Did not find 2 Directed_edges !!!" 
+ f"undirected edge {u_edge} has:" + f"edge.info['directed_edge_index'] = " + f"{u_edge.info['directed_edge_index']}" + f"len directed_edges_list = {len(self.directed_edges_list)}" + f"len undirected_edges_list = {len(self.undirected_edges_list)}" + ) # This UDE is valid to be considered as a node in Bond_Graph @@ -300,24 +302,26 @@ def line_graph_adjacency_list(self, cutoff): # DE1 should have center=center1 and DE2 should have center=center2 # We will need to find directed edges with center = center1 # and create angles with DE1, then do the same for center2 and DE2 - for center, DE in zip(u_edge.nodes, u_edge.info["directed_edge_index"]): + for center, dir_edge in zip( + u_edge.nodes, u_edge.info["directed_edge_index"] + ): for directed_edges in self.nodes[center].neighbors.values(): for directed_edge in directed_edges: - if directed_edge.index == DE: + if directed_edge.index == dir_edge: continue if directed_edge.info["distance"] < cutoff: line_graph.append( [ center, u_edge.index, - DE, + dir_edge, directed_edge.info["undirected_edge_index"], directed_edge.index, ] ) return line_graph, undirected2directed - def undirected2directed(self): + def undirected2directed(self) -> list[int]: """The index map from undirected_edge index to one of its directed_edge index. """ @@ -326,7 +330,7 @@ def undirected2directed(self): for undirected_edge in self.undirected_edges_list ] - def as_dict(self): + def as_dict(self) -> dict: """Return dictionary serialization of a Graph.""" return { "nodes": self.nodes, diff --git a/chgnet/model/basis.py b/chgnet/model/basis.py index 10c08a11..1c024b89 100644 --- a/chgnet/model/basis.py +++ b/chgnet/model/basis.py @@ -123,8 +123,8 @@ class GaussianExpansion(nn.Module): def __init__( self, - min: float = 0, - max: float = 5, + min: float = 0, # noqa: A002 + max: float = 5, # noqa: A002 step: float = 0.5, var: float | None = None, ) -> None: @@ -138,8 +138,10 @@ def __init__( var (float): variance in gaussian filter, default to step """ super().__init__() - assert min < max - assert max - min > step + if min >= max: + raise ValueError(f"{min=} must be less than {max=}") + if max - min <= step: + raise ValueError(f"{max - min=} must be greater than {step=}") self.register_buffer("gaussian_centers", torch.arange(min, max + step, step)) self.var = var or step if self.var <= 0: diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 65edfff6..4a7e0e99 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -63,7 +63,7 @@ def forward(self, graphs: list[CrystalGraph]) -> Tensor: composition_feas = self._assemble_graphs(graphs) return self._get_energy(composition_feas) - def _assemble_graphs(self, graphs: list[CrystalGraph]): + def _assemble_graphs(self, graphs: list[CrystalGraph]) -> Tensor: """Assemble a list of graphs into one-hot composition encodings. Args: @@ -99,7 +99,7 @@ def __init__( self.fc = nn.Linear(max_num_elements, 1, bias=False) self.fitted = False - def forward(self, graphs: list[CrystalGraph]): + def forward(self, graphs: list[CrystalGraph]) -> Tensor: """Get the energy of a list of CrystalGraphs. Args: @@ -108,7 +108,8 @@ def forward(self, graphs: list[CrystalGraph]): Returns: energy (tensor) """ - assert self.fitted is True, "composition model need to be fitted first!" 
+ if not self.fitted: + raise ValueError("composition model needs to be fitted first!") composition_feas = self._assemble_graphs(graphs) return self._get_energy(composition_feas) @@ -171,7 +172,7 @@ def fit( self.fc.load_state_dict(state_dict) self.fitted = True - def _assemble_graphs(self, graphs: list[CrystalGraph]): + def _assemble_graphs(self, graphs: list[CrystalGraph]) -> Tensor: """Assemble a list of graphs into one-hot composition encodings Args: graphs (list[Tensor]): a list of CrystalGraphs @@ -189,7 +190,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]): composition_feas.append(composition_fea) return torch.stack(composition_feas, dim=0).float() - def get_site_energies(self, graphs: list[CrystalGraph]): + def get_site_energies(self, graphs: list[CrystalGraph]) -> list[Tensor]: """Predict the site energies given a list of CrystalGraphs. Args: @@ -212,7 +213,7 @@ def initialize_from(self, dataset: str) -> None: else: raise NotImplementedError(f"{dataset=} not supported yet") - def initialize_from_MPtrj(self) -> None: + def initialize_from_MPtrj(self) -> None: # noqa: N802 """Initialize pre-fitted weights from MPtrj dataset.""" state_dict = collections.OrderedDict() state_dict["weight"] = torch.tensor( @@ -317,7 +318,7 @@ def initialize_from_MPtrj(self) -> None: self.is_intensive = True self.fitted = True - def initialize_from_MPF(self) -> None: + def initialize_from_MPF(self) -> None: # noqa: N802 """Initialize pre-fitted weights from MPF dataset.""" state_dict = collections.OrderedDict() state_dict["weight"] = torch.tensor( diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 38b34cde..b41aba6a 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -94,7 +94,9 @@ def __init__( print(f"CHGNet will run on {self.device}") @classmethod - def from_file(cls, path: str, use_device: str | None = None, **kwargs): + def from_file( + cls, path: str, use_device: str | None = None, **kwargs + ) -> CHGNetCalculator: """Load a user's CHGNet model and initialize the Calculator.""" return CHGNetCalculator( model=CHGNet.from_file(path), @@ -402,7 +404,7 @@ def __init__(self, atoms: Atoms) -> None: def __call__(self) -> None: """Record Atoms crystal feature vectors after an MD/relaxation step.""" - self.crystal_feature_vectors.append(self.atoms._calc.results["crystal_fea"]) + self.crystal_feature_vectors.append(self.atoms._calc.results["crystal_fea"]) # noqa: SLF001 def __len__(self) -> int: """Number of recorded steps.""" @@ -741,7 +743,7 @@ def upper_triangular_cell(self, *, verbose: bool | None = False) -> None: verbose (bool): Whether to notify user about upper-triangular cell transformation. Default = False """ - if not NPT._isuppertriangular(self.atoms.get_cell()): + if not NPT._isuppertriangular(self.atoms.get_cell()): # noqa: SLF001 a, b, c, alpha, beta, gamma = self.atoms.cell.cellpar() angles = np.radians((alpha, beta, gamma)) sin_a, sin_b, _sin_g = np.sin(angles) diff --git a/chgnet/model/functions.py b/chgnet/model/functions.py index 764b4ccc..6e96fc07 100644 --- a/chgnet/model/functions.py +++ b/chgnet/model/functions.py @@ -94,16 +94,16 @@ def __init__( ) self.layers = nn.Sequential(*layers) - def forward(self, X: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """Performs a forward pass through the MLP. 
Args: - X (Tensor): a tensor of shape (batch_size, input_dim) + x (Tensor): a tensor of shape (batch_size, input_dim) Returns: Tensor: a tensor of shape (batch_size, output_dim) """ - return self.layers(X) + return self.layers(x) class GatedMLP(nn.Module): @@ -164,21 +164,21 @@ def __init__( self.bn1 = find_normalization(name=norm, dim=output_dim) self.bn2 = find_normalization(name=norm, dim=output_dim) - def forward(self, X: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """Performs a forward pass through the MLP. Args: - X (Tensor): a tensor of shape (batch_size, input_dim) + x (Tensor): a tensor of shape (batch_size, input_dim) Returns: Tensor: a tensor of shape (batch_size, output_dim) """ if self.norm is None: - core = self.activation(self.mlp_core(X)) - gate = self.sigmoid(self.mlp_gate(X)) + core = self.activation(self.mlp_core(x)) + gate = self.sigmoid(self.mlp_gate(x)) else: - core = self.activation(self.bn1(self.mlp_core(X))) - gate = self.sigmoid(self.bn2(self.mlp_gate(X))) + core = self.activation(self.bn1(self.mlp_core(x))) + gate = self.sigmoid(self.bn2(self.mlp_gate(x))) return core * gate diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py index ebb055bb..d31fa0b4 100644 --- a/chgnet/model/layers.py +++ b/chgnet/model/layers.py @@ -27,7 +27,7 @@ def __init__( use_mlp_out: bool = True, mlp_out_bias: bool = False, resnet: bool = True, - gMLP_norm: str | None = None, + gMLP_norm: str | None = None, # noqa: N803 ) -> None: """Initialize the AtomConv layer. @@ -153,7 +153,7 @@ def __init__( use_mlp_out: bool = True, mlp_out_bias: bool = False, resnet=True, - gMLP_norm: str | None = None, + gMLP_norm: str | None = None, # noqa: N803 ) -> None: """Initialize the BondConv layer. @@ -279,7 +279,7 @@ def __init__( activation: str = "silu", norm: str | None = None, resnet: bool = True, - gMLP_norm: str | None = None, + gMLP_norm: str | None = None, # noqa: N803 ) -> None: """Initialize the AngleUpdate layer. diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 2207cd53..e11b28bd 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -62,7 +62,7 @@ def __init__( graph_converter_algorithm: Literal["legacy", "fast"] = "fast", cutoff_coeff: int = 8, learnable_rbf: bool = True, - gMLP_norm: str | None = "layer", + gMLP_norm: str | None = "layer", # noqa: N803 readout_norm: str | None = "layer", version: str | None = None, **kwargs, @@ -613,7 +613,7 @@ def predict_graph( magneton mu_B """ if not isinstance(graph, (CrystalGraph, Sequence)): - raise ValueError( + raise TypeError( f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs" ) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 64a223f7..c2cdb379 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -475,7 +475,7 @@ def get_best_model(self) -> CHGNet: """Get best model recorded in the trainer.""" if self.best_model is None: raise RuntimeError("the model needs to be trained first") - MAE = min(self.training_history["e"]["val"]) + MAE = min(self.training_history["e"]["val"]) # noqa: N806 print(f"Best model has val {MAE =:.4}") return self.best_model diff --git a/chgnet/utils/common_utils.py b/chgnet/utils/common_utils.py index e6cfc7ec..a2d94b07 100644 --- a/chgnet/utils/common_utils.py +++ b/chgnet/utils/common_utils.py @@ -12,7 +12,7 @@ def determine_device( use_device: str | None = None, *, check_cuda_mem: bool = True, -): +) -> str: """Determine the device to use for torch model. 
Args: diff --git a/examples/make_graphs.py b/examples/make_graphs.py index 71326cda..ba209f43 100644 --- a/examples/make_graphs.py +++ b/examples/make_graphs.py @@ -71,7 +71,7 @@ def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool: def make_partition( - data, graph_dir, train_ratio=0.8, val_ratio=0.1, partition_with_frame=False + data, graph_dir, train_ratio=0.8, val_ratio=0.1, *, partition_with_frame=False ) -> None: """Make a train val test partition.""" random.seed(42) diff --git a/pyproject.toml b/pyproject.toml index 65851527..ab1e324c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,41 +52,12 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } target-version = "py39" include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"] [tool.ruff.lint] -select = [ - "B", # flake8-bugbear - "C4", # flake8-comprehensions - "D", # pydocstyle - "E", # pycodestyle error - "EXE", # flake8-executable - "F", # pyflakes - "FA", # flake8-future-annotations - "FLY", # flynt - "I", # isort - "ICN", # flake8-import-conventions - "ISC", # flake8-implicit-str-concat - "PD", # pandas-vet - "PERF", # perflint - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PYI", # flakes8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-raise - "RUF", # Ruff-specific rules - "SIM", # flake8-simplify - "SLOT", # flakes8-slot - "T201", - "TCH", # flake8-type-checking - "TID", # tidy imports - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 -] +select = ["ALL"] ignore = [ "ANN001", # TODO add missing type annotations "ANN003", + "ANN101", + "ANN102", "B019", # Use of functools.lru_cache on methods can lead to memory leaks "BLE001", "C408", # unnecessary-collection-call @@ -108,9 +79,12 @@ ignore = [ "PT013", # pytest-incorrect-pytest-import "PT019", # pytest-fixture-param-without-value "PTH", # prefer Path to os.path + "S108", "S301", # pickle can be unsafe "S310", + "S311", "TRY003", + "TRY300", ] pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"]
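
Note: the recurring pattern in this patch is replacing bare `assert` statements with explicit exceptions, and swapping `ValueError` for `TypeError` where the failure is a wrong argument type rather than a bad value. Explicit raises are not stripped under `python -O`, and flake8-bandit's S101 (pulled in by `select = ["ALL"]`) flags asserts outside of tests. A minimal sketch of that pattern, written against a hypothetical `Thing` class that is illustrative only and not part of CHGNet:

from __future__ import annotations


class Thing:
    """Hypothetical example of the assert-to-raise pattern applied in this patch."""

    def __init__(
        self,
        values: list[float],
        *,
        min: float = 0,  # noqa: A002  (shadows builtin, same suppression as GaussianExpansion)
        max: float = 5,  # noqa: A002
    ) -> None:
        # Wrong type -> TypeError (previously such checks used ValueError or assert)
        if not isinstance(values, list):
            raise TypeError(f"values must be a list, got {type(values).__name__}")
        # Wrong value -> ValueError; unlike assert, this check survives `python -O`
        if min >= max:
            raise ValueError(f"{min=} must be less than {max=}")
        self.values = values


if __name__ == "__main__":
    Thing([1.0, 2.0])  # valid construction
    try:
        Thing("not a list")  # type: ignore[arg-type]
    except TypeError as exc:
        print(exc)

The `# noqa: A002` suppressions mirror how the patch handles intentional rule violations under `select = ["ALL"]`: keep the strict rule set globally and opt out line by line, rather than disabling the rule project-wide.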