Skip to content

Commit

Permalink
ruff select = ["ALL"] and fix legacy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed May 19, 2024
1 parent 934b71f commit 7b95af3
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 95 deletions.
14 changes: 7 additions & 7 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
"""
for idx, struct in enumerate(structures):
if not isinstance(struct, Structure):
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
raise TypeError(f"{idx} is not a pymatgen Structure object: {struct}")
for name in "energies forces stresses magmoms structure_ids".split():
labels = locals()[name]
if labels is not None and len(labels) != len(structures):
Expand Down Expand Up @@ -97,7 +97,7 @@ def from_vasp(
save_path: str | None = None,
graph_converter: CrystalGraphConverter | None = None,
shuffle: bool = True,
):
) -> StructureData:
"""Parse VASP output files into structures and labels and feed into the dataset.
Args:
Expand Down Expand Up @@ -586,7 +586,7 @@ def __init__(
elif isinstance(data, dict):
self.data = data
else:
raise ValueError(f"data must be JSON path or dictionary, got {type(data)}")
raise TypeError(f"data must be JSON path or dictionary, got {type(data)}")

self.keys = [
(mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
Expand All @@ -608,7 +608,7 @@ def __len__(self) -> int:
return len(self.keys)

@functools.cache # Cache loaded structures
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
"""Get one item in the dataset.
Returns:
Expand Down Expand Up @@ -754,7 +754,7 @@ def get_train_val_test_loader(
return train_loader, val_loader, test_loader


def collate_graphs(batch_data: list):
def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tensor]]:
"""Collate of list of (graph, target) into batch data.
Args:
Expand Down Expand Up @@ -791,7 +791,7 @@ def get_train_val_test_loader(
return_test: bool = True,
num_workers: int = 0,
pin_memory: bool = True,
):
) -> tuple[DataLoader, DataLoader, DataLoader]:
"""Randomly partition a dataset into train, val, test loaders.
Args:
Expand Down Expand Up @@ -852,7 +852,7 @@ def get_train_val_test_loader(

def get_loader(
dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True
):
) -> DataLoader:
"""Get a dataloader from a dataset.
Args:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __repr__(self) -> str:
)

@property
def num_isolated_atoms(self):
def num_isolated_atoms(self) -> int:
"""Number of isolated atoms given the atom graph cutoff
Isolated atoms are disconnected nodes in the atom graph
that will not get updated in CHGNet.
Expand Down
52 changes: 28 additions & 24 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class UndirectedEdge(Edge):

__hash__ = Edge.__hash__

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Check if two undirected edges are equal."""
return set(self.nodes) == set(other.nodes) and self.info == other.info

Expand Down Expand Up @@ -178,16 +178,16 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
):
# There is an undirected edge with similar length and only one of
# the directed edges associated has been added
added_DE = self.directed_edges_list[
added_dir_edge = self.directed_edges_list[
undirected_edge.info["directed_edge_index"][0]
]

# See if the DE that's associated to this UDE
# is the reverse of our DE
if added_DE == this_directed_edge:
if added_dir_edge == this_directed_edge:
# Add UDE index to this DE
this_directed_edge.info["undirected_edge_index"] = (
added_DE.info["undirected_edge_index"]
added_dir_edge.info["undirected_edge_index"]
)

# At the center node, draw edge with this DE
Expand Down Expand Up @@ -217,7 +217,7 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
self.nodes[center_index].add_neighbor(neighbor_index, this_directed_edge)
self.directed_edges_list.append(this_directed_edge)

def adjacency_list(self):
def adjacency_list(self) -> tuple[list[list[int]], list[int]]:
"""Get the adjacency list
Return:
graph: the adjacency list
Expand All @@ -240,7 +240,7 @@ def adjacency_list(self):
]
return graph, directed2undirected

def line_graph_adjacency_list(self, cutoff):
def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]:
"""Get the line graph adjacency list.
Args:
Expand All @@ -264,11 +264,12 @@ def line_graph_adjacency_list(self, cutoff):
a list of length = num_undirected_edge that
maps the undirected edge index to one of its directed edges indices
"""
assert len(self.directed_edges_list) == 2 * len(self.undirected_edges_list), (
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 * "
f"number of undirected edges={len(self.directed_edges_list)}!"
f"This indicates directed edges are not complete"
)
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
raise ValueError(
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
f"* number of undirected edges={len(self.directed_edges_list)}!"
f"This indicates directed edges are not complete"
)
line_graph = []
undirected2directed = []

Expand All @@ -285,39 +286,42 @@ def line_graph_adjacency_list(self, cutoff):
# if encountered exception,
# it means after Atom_Graph creation, the UDE has only 1 DE associated
# This exception is not encountered from the develop team's experience
assert len(u_edge.info["directed_edge_index"]) == 2, (
"Did not find 2 Directed_edges !!!"
f"undirected edge {u_edge} has:"
f"edge.info['directed_edge_index'] = "
f"{u_edge.info['directed_edge_index']}"
f"len directed_edges_list = {len(self.directed_edges_list)}"
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
)
if len(u_edge.info["directed_edge_index"]) != 2:
raise ValueError(
"Did not find 2 Directed_edges !!!"
f"undirected edge {u_edge} has:"
f"edge.info['directed_edge_index'] = "
f"{u_edge.info['directed_edge_index']}"
f"len directed_edges_list = {len(self.directed_edges_list)}"
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
)

# This UDE is valid to be considered as a node in Bond_Graph

# Get the two ends (centers) and the two DE associated with this UDE
# DE1 should have center=center1 and DE2 should have center=center2
# We will need to find directed edges with center = center1
# and create angles with DE1, then do the same for center2 and DE2
for center, DE in zip(u_edge.nodes, u_edge.info["directed_edge_index"]):
for center, dir_edge in zip(
u_edge.nodes, u_edge.info["directed_edge_index"]
):
for directed_edges in self.nodes[center].neighbors.values():
for directed_edge in directed_edges:
if directed_edge.index == DE:
if directed_edge.index == dir_edge:
continue
if directed_edge.info["distance"] < cutoff:
line_graph.append(
[
center,
u_edge.index,
DE,
dir_edge,
directed_edge.info["undirected_edge_index"],
directed_edge.index,
]
)
return line_graph, undirected2directed

def undirected2directed(self):
def undirected2directed(self) -> list[int]:
"""The index map from undirected_edge index to one of its directed_edge
index.
"""
Expand All @@ -326,7 +330,7 @@ def undirected2directed(self):
for undirected_edge in self.undirected_edges_list
]

def as_dict(self):
def as_dict(self) -> dict:
"""Return dictionary serialization of a Graph."""
return {
"nodes": self.nodes,
Expand Down
10 changes: 6 additions & 4 deletions chgnet/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ class GaussianExpansion(nn.Module):

def __init__(
self,
min: float = 0,
max: float = 5,
min: float = 0, # noqa: A002
max: float = 5, # noqa: A002
step: float = 0.5,
var: float | None = None,
) -> None:
Expand All @@ -138,8 +138,10 @@ def __init__(
var (float): variance in gaussian filter, default to step
"""
super().__init__()
assert min < max
assert max - min > step
if min >= max:
raise ValueError(f"{min=} must be less than {max=}")
if max - min <= step:
raise ValueError(f"{max - min=} must be greater than {step=}")
self.register_buffer("gaussian_centers", torch.arange(min, max + step, step))
self.var = var or step
if self.var <= 0:
Expand Down
15 changes: 8 additions & 7 deletions chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(self, graphs: list[CrystalGraph]) -> Tensor:
composition_feas = self._assemble_graphs(graphs)
return self._get_energy(composition_feas)

def _assemble_graphs(self, graphs: list[CrystalGraph]):
def _assemble_graphs(self, graphs: list[CrystalGraph]) -> Tensor:
"""Assemble a list of graphs into one-hot composition encodings.
Args:
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
self.fc = nn.Linear(max_num_elements, 1, bias=False)
self.fitted = False

def forward(self, graphs: list[CrystalGraph]):
def forward(self, graphs: list[CrystalGraph]) -> Tensor:
"""Get the energy of a list of CrystalGraphs.
Args:
Expand All @@ -108,7 +108,8 @@ def forward(self, graphs: list[CrystalGraph]):
Returns:
energy (tensor)
"""
assert self.fitted is True, "composition model need to be fitted first!"
if not self.fitted:
raise ValueError("composition model needs to be fitted first!")
composition_feas = self._assemble_graphs(graphs)
return self._get_energy(composition_feas)

Expand Down Expand Up @@ -171,7 +172,7 @@ def fit(
self.fc.load_state_dict(state_dict)
self.fitted = True

def _assemble_graphs(self, graphs: list[CrystalGraph]):
def _assemble_graphs(self, graphs: list[CrystalGraph]) -> Tensor:
"""Assemble a list of graphs into one-hot composition encodings
Args:
graphs (list[Tensor]): a list of CrystalGraphs
Expand All @@ -189,7 +190,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]):
composition_feas.append(composition_fea)
return torch.stack(composition_feas, dim=0).float()

def get_site_energies(self, graphs: list[CrystalGraph]):
def get_site_energies(self, graphs: list[CrystalGraph]) -> list[Tensor]:
"""Predict the site energies given a list of CrystalGraphs.
Args:
Expand All @@ -212,7 +213,7 @@ def initialize_from(self, dataset: str) -> None:
else:
raise NotImplementedError(f"{dataset=} not supported yet")

def initialize_from_MPtrj(self) -> None:
def initialize_from_MPtrj(self) -> None: # noqa: N802
"""Initialize pre-fitted weights from MPtrj dataset."""
state_dict = collections.OrderedDict()
state_dict["weight"] = torch.tensor(
Expand Down Expand Up @@ -317,7 +318,7 @@ def initialize_from_MPtrj(self) -> None:
self.is_intensive = True
self.fitted = True

def initialize_from_MPF(self) -> None:
def initialize_from_MPF(self) -> None: # noqa: N802
"""Initialize pre-fitted weights from MPF dataset."""
state_dict = collections.OrderedDict()
state_dict["weight"] = torch.tensor(
Expand Down
8 changes: 5 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def __init__(
print(f"CHGNet will run on {self.device}")

@classmethod
def from_file(cls, path: str, use_device: str | None = None, **kwargs):
def from_file(
cls, path: str, use_device: str | None = None, **kwargs
) -> CHGNetCalculator:
"""Load a user's CHGNet model and initialize the Calculator."""
return CHGNetCalculator(
model=CHGNet.from_file(path),
Expand Down Expand Up @@ -402,7 +404,7 @@ def __init__(self, atoms: Atoms) -> None:

def __call__(self) -> None:
"""Record Atoms crystal feature vectors after an MD/relaxation step."""
self.crystal_feature_vectors.append(self.atoms._calc.results["crystal_fea"])
self.crystal_feature_vectors.append(self.atoms._calc.results["crystal_fea"]) # noqa: SLF001

def __len__(self) -> int:
"""Number of recorded steps."""
Expand Down Expand Up @@ -741,7 +743,7 @@ def upper_triangular_cell(self, *, verbose: bool | None = False) -> None:
verbose (bool): Whether to notify user about upper-triangular cell
transformation. Default = False
"""
if not NPT._isuppertriangular(self.atoms.get_cell()):
if not NPT._isuppertriangular(self.atoms.get_cell()): # noqa: SLF001
a, b, c, alpha, beta, gamma = self.atoms.cell.cellpar()
angles = np.radians((alpha, beta, gamma))
sin_a, sin_b, _sin_g = np.sin(angles)
Expand Down
18 changes: 9 additions & 9 deletions chgnet/model/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ def __init__(
)
self.layers = nn.Sequential(*layers)

def forward(self, X: Tensor) -> Tensor:
def forward(self, x: Tensor) -> Tensor:
"""Performs a forward pass through the MLP.
Args:
X (Tensor): a tensor of shape (batch_size, input_dim)
x (Tensor): a tensor of shape (batch_size, input_dim)
Returns:
Tensor: a tensor of shape (batch_size, output_dim)
"""
return self.layers(X)
return self.layers(x)


class GatedMLP(nn.Module):
Expand Down Expand Up @@ -164,21 +164,21 @@ def __init__(
self.bn1 = find_normalization(name=norm, dim=output_dim)
self.bn2 = find_normalization(name=norm, dim=output_dim)

def forward(self, X: Tensor) -> Tensor:
def forward(self, x: Tensor) -> Tensor:
"""Performs a forward pass through the MLP.
Args:
X (Tensor): a tensor of shape (batch_size, input_dim)
x (Tensor): a tensor of shape (batch_size, input_dim)
Returns:
Tensor: a tensor of shape (batch_size, output_dim)
"""
if self.norm is None:
core = self.activation(self.mlp_core(X))
gate = self.sigmoid(self.mlp_gate(X))
core = self.activation(self.mlp_core(x))
gate = self.sigmoid(self.mlp_gate(x))
else:
core = self.activation(self.bn1(self.mlp_core(X)))
gate = self.sigmoid(self.bn2(self.mlp_gate(X)))
core = self.activation(self.bn1(self.mlp_core(x)))
gate = self.sigmoid(self.bn2(self.mlp_gate(x)))
return core * gate


Expand Down
6 changes: 3 additions & 3 deletions chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
use_mlp_out: bool = True,
mlp_out_bias: bool = False,
resnet: bool = True,
gMLP_norm: str | None = None,
gMLP_norm: str | None = None, # noqa: N803
) -> None:
"""Initialize the AtomConv layer.
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
use_mlp_out: bool = True,
mlp_out_bias: bool = False,
resnet=True,
gMLP_norm: str | None = None,
gMLP_norm: str | None = None, # noqa: N803
) -> None:
"""Initialize the BondConv layer.
Expand Down Expand Up @@ -279,7 +279,7 @@ def __init__(
activation: str = "silu",
norm: str | None = None,
resnet: bool = True,
gMLP_norm: str | None = None,
gMLP_norm: str | None = None, # noqa: N803
) -> None:
"""Initialize the AngleUpdate layer.
Expand Down
Loading

0 comments on commit 7b95af3

Please sign in to comment.