Commit 6c83227

ruff auto fixes

janosh committed May 19, 2024
1 parent 455f4d8

Showing 10 changed files with 21 additions and 24 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml

@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.3
+    rev: v0.4.4
     hooks:
       - id: ruff
         args: [--fix]
@@ -46,7 +46,7 @@ repos:
           - svelte

   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v9.2.0
+    rev: v9.3.0
     hooks:
       - id: eslint
         types: [file]
4 changes: 2 additions & 2 deletions chgnet/data/dataset.py

@@ -80,7 +80,7 @@ def __init__(
         self.keys = np.arange(len(structures))
         if shuffle:
             random.shuffle(self.keys)
-        print(f"{len(structures)} structures imported")
+        print(f"{type(self).__name__} imported {len(structures):,} structures")
         self.graph_converter = graph_converter or CrystalGraphConverter(
             atom_graph_cutoff=6, bond_graph_cutoff=3
         )
@@ -587,7 +587,7 @@ def __init__(
         ]
         if shuffle:
             random.shuffle(self.keys)
-        print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
+        print(f"{len(self.data)} MP IDs, {len(self)} structures imported")
         self.graph_converter = graph_converter
         self.energy_key = energy_key
         self.force_key = force_key
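The new message leans on two stdlib features worth noting: type(self).__name__ resolves to the subclass name at runtime, and the ":," format spec groups thousands with commas. A minimal sketch (class name and count invented for illustration):

class StructureData:
    def __init__(self, structures: list) -> None:
        # type(self).__name__ picks up the subclass name, so the message stays
        # accurate for any subclass without hard-coding "StructureData".
        print(f"{type(self).__name__} imported {len(structures):,} structures")

StructureData(structures=[None] * 12345)  # prints: StructureData imported 12,345 structures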
1 change: 0 additions & 1 deletion chgnet/graph/converter.py

@@ -274,7 +274,6 @@ def set_isolated_atom_response(
             None
         """
         self.on_isolated_atoms = on_isolated_atoms
-        return

     def as_dict(self) -> dict[str, str | float]:
        """Save the args of the graph converter."""
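Context for this hunk: a bare return as the last statement of a function is a no-op, since Python returns None implicitly when control falls off the end of the body; this is likely ruff's useless-return rule (PLR1711, inherited from pylint) auto-fixing it. A before/after sketch with invented names:

def set_flag(obj, value):
    obj.flag = value
    return  # before: a final bare return adds nothing

def set_flag_fixed(obj, value):
    obj.flag = value  # after: falling off the end returns None implicitly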
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py

@@ -202,7 +202,7 @@ def get_site_energies(self, graphs: list[CrystalGraph]):

     def initialize_from(self, dataset: str) -> None:
         """Initialize pre-fitted weights from a dataset."""
-        if dataset in ["MPtrj", "MPtrj_e"]:
+        if dataset in {"MPtrj", "MPtrj_e"}:
             self.initialize_from_MPtrj()
         elif dataset == "MPF":
             self.initialize_from_MPF()
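This and most of the remaining hunks apply one fix: membership tests against a literal list or tuple become tests against a set literal, presumably ruff's literal-membership rule from the pylint-refactor group (PLR6201). On CPython, a set literal of constants used with "in" is folded into a frozenset constant at compile time, giving a single hashed lookup instead of a linear scan, and the set states the intent (an unordered collection of allowed values) more directly. A small demonstration:

import dis

def check(dataset: str) -> bool:
    return dataset in {"MPtrj", "MPtrj_e"}

# The disassembly shows LOAD_CONST frozenset({'MPtrj', 'MPtrj_e'}): the set
# literal was folded into a constant, so membership is one hashed lookup.
dis.dis(check)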
2 changes: 1 addition & 1 deletion chgnet/model/dynamics.py

@@ -890,6 +890,6 @@ def get_compressibility(self, unit: str = "A^3/eV") -> float:
             return 1 / self.bm.b0
         if unit == "GPa^-1":
             return 1 / self.bm.b0_GPa
-        if unit in ("Pa^-1", "m^2/N"):
+        if unit in {"Pa^-1", "m^2/N"}:
             return 1 / (self.bm.b0_GPa * 1e9)
         raise NotImplementedError("unit has to be one of A^3/eV, GPa^-1 Pa^-1 or m^2/N")
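For context, the surrounding method converts a bulk modulus B0 given in GPa into a compressibility beta = 1/B0: since 1 GPa = 1e9 Pa, dividing by B0[GPa] * 1e9 yields Pa^-1, which equals m^2/N because Pa = N/m^2. A worked check with a made-up B0:

b0_GPa = 200.0  # hypothetical bulk modulus, roughly steel-like
beta_per_GPa = 1 / b0_GPa         # 5.0e-3 GPa^-1
beta_per_Pa = 1 / (b0_GPa * 1e9)  # 5.0e-12 Pa^-1 (equivalently m^2/N)
assert abs(beta_per_Pa - beta_per_GPa * 1e-9) < 1e-24  # same quantity, rescaled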
2 changes: 1 addition & 1 deletion chgnet/model/functions.py

@@ -67,7 +67,7 @@ def __init__(
                 Default = True
         """
         super().__init__()
-        if hidden_dim in (None, 0):
+        if hidden_dim in {None, 0}:
             layers = [nn.Dropout(dropout), nn.Linear(input_dim, output_dim, bias=bias)]
         elif isinstance(hidden_dim, int):
             layers = [
10 changes: 5 additions & 5 deletions chgnet/model/model.py

@@ -150,9 +150,9 @@ def __init__(
         """
         # Store model args for reconstruction
         self.model_args = {
-            k: v
-            for k, v in locals().items()
-            if k not in ["self", "__class__", "kwargs"]
+            key: val
+            for key, val in locals().items()
+            if key not in {"self", "__class__", "kwargs"}
         }
         self.model_args.update(kwargs)
         if version:
@@ -279,7 +279,7 @@ def __init__(
             self.read_out_type = "sum"
             input_dim = atom_fea_dim
             self.pooling = GraphPooling(average=False)
-        elif read_out in ["attn", "weighted"]:
+        elif read_out in {"attn", "weighted"}:
             self.read_out_type = "attn"
             num_heads = kwargs.pop("num_heads", 3)
             self.pooling = GraphAttentionReadOut(
@@ -290,7 +290,7 @@ def __init__(
             self.read_out_type = "ave"
             input_dim = atom_fea_dim
             self.pooling = GraphPooling(average=True)
-        if kwargs.pop("final_mlp", "MLP") in ["normal", "MLP"]:
+        if kwargs.pop("final_mlp", "MLP") in {"normal", "MLP"}:
             self.mlp = MLP(
                 input_dim=input_dim,
                 hidden_dim=mlp_hidden_dims,
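Besides the list-to-set fixes, the first hunk renames the throwaway comprehension variables k, v to the more readable key, val. The pattern itself deserves a note: calling locals() in the outermost for-clause of the comprehension, before any other local is assigned, snapshots exactly the bound constructor arguments, so the model can later be rebuilt from model_args alone. A stripped-down sketch of the idea (class and parameters invented for illustration):

class ToyModel:
    def __init__(self, atom_fea_dim: int = 64, num_layers: int = 4, **kwargs) -> None:
        # The iterable of the outermost for-clause is evaluated in the enclosing
        # scope, so locals() here sees __init__'s arguments, not the
        # comprehension's own scope. Drop the non-argument entries.
        self.model_args = {
            key: val
            for key, val in locals().items()
            if key not in {"self", "__class__", "kwargs"}
        }
        self.model_args.update(kwargs)

model = ToyModel(num_layers=6)
clone = ToyModel(**model.model_args)  # round-trips the constructor args
assert clone.model_args == model.model_args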
14 changes: 7 additions & 7 deletions chgnet/trainer/trainer.py

@@ -93,7 +93,7 @@ def __init__(
         self.trainer_args = {
             k: v
             for k, v in locals().items()
-            if k not in ["self", "__class__", "model", "kwargs"]
+            if k not in {"self", "__class__", "model", "kwargs"}
         }
         self.trainer_args.update(kwargs)

@@ -132,7 +132,7 @@ def __init__(
         )

         # Define learning rate scheduler
-        if scheduler in ["MultiStepLR", "multistep"]:
+        if scheduler in {"MultiStepLR", "multistep"}:
             scheduler_params = kwargs.pop(
                 "scheduler_params",
                 {
@@ -142,11 +142,11 @@ def __init__(
             )
             self.scheduler = MultiStepLR(self.optimizer, **scheduler_params)
             self.scheduler_type = "multistep"
-        elif scheduler in ["ExponentialLR", "Exp", "Exponential"]:
+        elif scheduler in {"ExponentialLR", "Exp", "Exponential"}:
             scheduler_params = kwargs.pop("scheduler_params", {"gamma": 0.98})
             self.scheduler = ExponentialLR(self.optimizer, **scheduler_params)
             self.scheduler_type = "exp"
-        elif scheduler in ["CosineAnnealingLR", "CosLR", "Cos", "cos"]:
+        elif scheduler in {"CosineAnnealingLR", "CosLR", "Cos", "cos"}:
             scheduler_params = kwargs.pop("scheduler_params", {"decay_fraction": 1e-2})
             decay_fraction = scheduler_params.pop("decay_fraction")
             self.scheduler = CosineAnnealingLR(
@@ -481,7 +481,7 @@ def _init_keys(self) -> list[str]:
         return [
             key
             for key in list(inspect.signature(Trainer.__init__).parameters)
-            if key not in (["self", "model", "kwargs"])
+            if key not in {"self", "model", "kwargs"}
         ]

     def save(self, filename: str = "training_result.pth.tar") -> None:
@@ -602,9 +602,9 @@ def __init__(
         """
         super().__init__()
         # Define loss criterion
-        if criterion in ["MSE", "mse"]:
+        if criterion in {"MSE", "mse"}:
             self.criterion = nn.MSELoss()
-        elif criterion in ["MAE", "mae", "l1"]:
+        elif criterion in {"MAE", "mae", "l1"}:
             self.criterion = nn.L1Loss()
         elif criterion == "Huber":
             self.criterion = nn.HuberLoss(delta=delta)
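The trainer hunks all follow the same shape: each branch accepts several spellings of one option, which is exactly the unordered-allowed-values case where a set literal fits. A minimal, self-contained sketch of the criterion dispatch mirroring the last hunk (the function name and the delta default are invented here):

from torch import nn

def make_criterion(criterion: str = "MSE", delta: float = 0.1) -> nn.Module:
    # Each set holds the accepted spellings for one loss type.
    if criterion in {"MSE", "mse"}:
        return nn.MSELoss()
    if criterion in {"MAE", "mae", "l1"}:
        return nn.L1Loss()
    if criterion == "Huber":
        return nn.HuberLoss(delta=delta)
    raise NotImplementedError(f"Unknown criterion: {criterion!r}")

loss_fn = make_criterion("l1")  # returns nn.L1Loss()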
2 changes: 1 addition & 1 deletion chgnet/utils/common_utils.py

@@ -22,7 +22,7 @@ def determine_device(
         device (str): device name to be passed to model.to(device)
     """
     use_device = use_device or os.getenv("CHGNET_DEVICE")
-    if use_device in ("mps", None) and torch.backends.mps.is_available():
+    if use_device in {"mps", None} and torch.backends.mps.is_available():
         device = "mps"
     else:
         device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
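Reading the changed condition: fall back to Apple's Metal backend when the caller asked for "mps" or expressed no preference (neither argument nor CHGNET_DEVICE env var) and MPS is available; otherwise prefer CUDA over CPU. A self-contained restatement of the same logic:

from __future__ import annotations

import os
import torch

def determine_device(use_device: str | None = None) -> str:
    """Pick mps > cuda > cpu, honoring an explicit request or CHGNET_DEVICE."""
    use_device = use_device or os.getenv("CHGNET_DEVICE")
    if use_device in {"mps", None} and torch.backends.mps.is_available():
        return "mps"
    return use_device or ("cuda" if torch.cuda.is_available() else "cpu")

print(determine_device())  # e.g. "cuda" on a CUDA box, "cpu" otherwise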
4 changes: 1 addition & 3 deletions pyproject.toml

@@ -99,8 +99,6 @@ ignore = [
   "E731", # do not assign a lambda expression, use a def
   "EM",
   "ERA001", # found commented out code
-  "FBT001", # Boolean positional argument in function
-  "FBT002", # Boolean keyword argument in function
   "ISC001",
   "NPY002", # TODO replace legacy np.random.seed
   "PLR", # pylint refactor
@@ -120,7 +118,7 @@ isort.split-on-trailing-comma = false

 [tool.ruff.lint.per-file-ignores]
 "site/*" = ["INP001", "S602"]
-"tests/*" = ["ANN201", "D103", "INP001", "S101"]
+"tests/*" = ["ANN201", "D100", "D103", "FBT001", "FBT002", "INP001", "S101"]
 # E402 Module level import not at top of file
 "examples/*" = ["E402", "I002", "INP001", "N816", "S101", "T201"]
 "chgnet/**/*" = ["T201"]
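Net effect of this hunk: the boolean-trap rules FBT001/FBT002 are no longer ignored repo-wide, only under tests/*, and D100 (missing module docstring) joins the test exemptions. For reference, what the FBT rules flag and the usual keyword-only fix (names invented for illustration):

def train(model, verbose: bool = True):  # flagged: boolean positional argument
    ...

def train_fixed(model, *, verbose: bool = True):  # keyword-only boolean: call
    ...  # sites must spell out train_fixed(m, verbose=False), which reads clearly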
