cuda_devices_sorted_by_free_mem() return [] if `not torch.cuda.is_available()` (#115)

* cuda_devices_sorted_by_free_mem() return [] if not torch.cuda.is_available()

* refactor test_md_crystal_feas_log

* change parse_vasp_dir() empty data exception type to RuntimeError

* remove unused mkdir util

* type annos
janosh authored Jan 25, 2024
1 parent 544eb9e commit 007f682
Showing 11 changed files with 53 additions and 59 deletions.
2 changes: 1 addition & 1 deletion chgnet/model/dynamics.py
@@ -393,7 +393,7 @@ def __len__(self) -> int:
         return len(self.crystal_feature_vectors)
 
     def save(self, filename: str) -> None:
-        """Save the crystal feature vectors to file."""
+        """Save the crystal feature vectors to filename in pickle format."""
         out_pkl = {"crystal_feas": self.crystal_feature_vectors}
         with open(filename, "wb") as file:
             pickle.dump(out_pkl, file)
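Since save() pickles a single dict keyed by "crystal_feas", reading a feature log back is symmetric. A minimal sketch (the filename is hypothetical, matching whatever was passed to save()):

    import pickle

    with open("md_crystal_feas.pkl", "rb") as file:  # hypothetical log file written by save()
        crystal_feas = pickle.load(file)["crystal_feas"]  # one feature vector per logged MD step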
14 changes: 7 additions & 7 deletions chgnet/model/model.py
@@ -645,31 +645,31 @@ def predict_graph(
 
         return predictions[0] if len(graphs) == 1 else predictions
 
-    def as_dict(self):
+    def as_dict(self) -> dict:
         """Return the CHGNet weights and args in a dictionary."""
         return {"state_dict": self.state_dict(), "model_args": self.model_args}
 
-    def todict(self):
+    def todict(self) -> dict:
         """Needed for ASE JSON serialization when saving CHGNet potential to
         trajectory file (https://github.com/CederGroupHub/chgnet/issues/48).
         """
         return {"model_name": type(self).__name__, "model_args": self.model_args}
 
     @classmethod
-    def from_dict(cls, dict, **kwargs):
+    def from_dict(cls, dct: dict, **kwargs) -> CHGNet:
         """Build a CHGNet from a saved dictionary."""
-        chgnet = CHGNet(**dict["model_args"], **kwargs)
-        chgnet.load_state_dict(dict["state_dict"])
+        chgnet = CHGNet(**dct["model_args"], **kwargs)
+        chgnet.load_state_dict(dct["state_dict"])
         return chgnet
 
     @classmethod
-    def from_file(cls, path, **kwargs):
+    def from_file(cls, path, **kwargs) -> CHGNet:
         """Build a CHGNet from a saved file."""
         state = torch.load(path, map_location=torch.device("cpu"))
         return CHGNet.from_dict(state["model"], **kwargs)
 
     @classmethod
-    def load(cls, model_name="0.3.0", use_device: str | None = None):
+    def load(cls, model_name="0.3.0", use_device: str | None = None) -> CHGNet:
         """Load pretrained CHGNet model.
         Args:
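as_dict and from_dict mirror each other, and from_file expects the as_dict payload nested under a "model" key. A round-trip sketch under those assumptions (the checkpoint path is hypothetical):

    import torch
    from chgnet.model import CHGNet

    model = CHGNet.load()  # pretrained weights, model_name defaults to "0.3.0"
    torch.save({"model": model.as_dict()}, "chgnet_ckpt.pth.tar")  # hypothetical path
    reloaded = CHGNet.from_file("chgnet_ckpt.pth.tar")  # torch.load + from_dict under the hood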
12 changes: 6 additions & 6 deletions chgnet/trainer/trainer.py
@@ -467,7 +467,7 @@ def _validate(
             print(message)
         return {k: round(mae_error.avg, 6) for k, mae_error in mae_errors.items()}
 
-    def get_best_model(self):
+    def get_best_model(self) -> CHGNet:
         """Get best model recorded in the trainer."""
         if self.best_model is None:
             raise RuntimeError("the model needs to be trained first")
@@ -476,7 +476,7 @@ def get_best_model(self):
         return self.best_model
 
     @property
-    def _init_keys(self):
+    def _init_keys(self) -> list[str]:
         return [
             key
             for key in list(inspect.signature(Trainer.__init__).parameters)
@@ -680,23 +680,23 @@ def forward(
         # Mag
         if self.mag_loss_ratio != 0 and "m" in targets:
             mag_preds, mag_targets = [], []
-            m_MAE_size = 0
+            m_mae_size = 0
             for mag_pred, mag_target in zip(prediction["m"], targets["m"]):
                 # exclude structures without magmom labels
                 if mag_target is not None:
                     mag_preds.append(mag_pred)
                     mag_targets.append(mag_target)
-                    m_MAE_size += mag_target.shape[0]
+                    m_mae_size += mag_target.shape[0]
             if mag_targets != []:
                 mag_preds = torch.cat(mag_preds, dim=0)
                 mag_targets = torch.cat(mag_targets, dim=0)
                 out["loss"] += self.mag_loss_ratio * self.criterion(
                     mag_targets, mag_preds
                 )
                 out["m_MAE"] = mae(mag_targets, mag_preds)
-                out["m_MAE_size"] = m_MAE_size
+                out["m_MAE_size"] = m_mae_size
             else:
                 out["m_MAE"] = torch.zeros([1])
-                out["m_MAE_size"] = m_MAE_size
+                out["m_MAE_size"] = m_mae_size
 
         return out
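The magmom branch filters out unlabeled structures before concatenating, so m_mae_size counts only atoms that actually carry magmom labels. The same mask-then-concatenate pattern in isolation (toy tensors, not CHGNet code):

    import torch

    preds = [torch.tensor([2.1, 0.0]), torch.tensor([1.9]), torch.tensor([0.5, 0.7])]
    targets = [torch.tensor([2.0, 0.1]), None, torch.tensor([0.4, 0.8])]  # middle structure has no labels

    kept = [(p, t) for p, t in zip(preds, targets) if t is not None]
    mag_preds = torch.cat([p for p, _ in kept])
    mag_targets = torch.cat([t for _, t in kept])
    m_mae_size = sum(t.shape[0] for _, t in kept)  # 4 labeled atoms contribute
    m_mae = torch.mean(torch.abs(mag_targets - mag_preds))  # same formula as the mae() util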
1 change: 0 additions & 1 deletion chgnet/utils/__init__.py
@@ -4,7 +4,6 @@
     AverageMeter,
     cuda_devices_sorted_by_free_mem,
     mae,
-    mkdir,
     read_json,
     write_json,
 )
37 changes: 11 additions & 26 deletions chgnet/utils/common_utils.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import json
-import os
 
 import nvidia_smi
 import torch
@@ -13,6 +12,9 @@ def cuda_devices_sorted_by_free_mem() -> list[int]:
     To get the device with the most free memory, use the last list item.
     """
+    if not torch.cuda.is_available():
+        return []
+
     free_memories = []
     nvidia_smi.nvmlInit()
     device_count = nvidia_smi.nvmlDeviceGetCount()
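With the early return, the helper no longer touches nvidia_smi on CPU-only machines, which makes device selection safe everywhere. A usage sketch:

    from chgnet.utils import cuda_devices_sorted_by_free_mem

    devices = cuda_devices_sorted_by_free_mem()  # [] when CUDA is unavailable
    device = f"cuda:{devices[-1]}" if devices else "cpu"  # last item = most free memory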
@@ -63,45 +65,28 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor:
     return torch.mean(torch.abs(target - prediction))
 
 
-def read_json(fjson: str):
+def read_json(filepath: str) -> dict:
     """Read the json file.
     Args:
-        fjson (str): file name of json to read.
+        filepath (str): file name of json to read.
     Returns:
         dictionary stored in fjson
     """
-    with open(fjson) as file:
+    with open(filepath) as file:
         return json.load(file)
 
 
-def write_json(d: dict, fjson: str):
+def write_json(dct: dict, filepath: str) -> dict:
     """Write the json file.
     Args:
-        d (dict): dictionary to write
-        fjson (str): file name of json to write.
+        dct (dict): dictionary to write
+        filepath (str): file name of json to write.
     Returns:
         written dictionary
     """
-    with open(fjson, "w") as file:
-        json.dump(d, file)
-
-
-def mkdir(path: str):
-    """Make directory.
-    Args:
-        path (str): directory name
-    Returns:
-        path
-    """
-    folder = os.path.exists(path)
-    if not folder:
-        os.makedirs(path)
-    else:
-        print("Folder exists")
-    return path
+    with open(filepath, "w") as file:
+        json.dump(dct, file)

Review thread on the mkdir removal:

bowen-bd (Collaborator, Jan 25, 2024): @janosh let's not remove this function. this will cause so many import errors

janosh (Author, Jan 25, 2024): are you referring to your own code? i would just replace it with os.makedirs(path, exist_ok=True). i doubt a lot of 3rd party users imported this util? but could be wrong there

bowen-bd (Collaborator, Jan 25, 2024): Yes, a lot of my old code used this function. I appreciate the suggestion and the cleaner code, but it takes time to fix all the old code if I need to rerun some of it in the future. Leaving this function here will cause no harm.

janosh (Author, Jan 25, 2024): true: 44658ec

bowen-bd (Collaborator, Jan 25, 2024): thanks!
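For the renamed JSON helpers, a quick usage round-trip (filename and contents hypothetical); note write_json hands the dict to json.dump, so it must be JSON-serializable:

    from chgnet.utils import read_json, write_json

    write_json({"mp_id": "mp-149", "energy_per_atom": -5.4}, "record.json")  # hypothetical data
    assert read_json("record.json") == {"mp_id": "mp-149", "energy_per_atom": -5.4}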
4 changes: 2 additions & 2 deletions chgnet/utils/vasp_utils.py
@@ -20,7 +20,7 @@ def parse_vasp_dir(
     Args:
         file_root (str): the directory of the VASP calculation outputs
         check_electronic_convergence (bool): if set to True, this function will raise
-            Exception to VASP calculation that did not achieve
+            Exception to VASP calculation that did not achieve electronic convergence.
     """
     try:
         oszicar = Oszicar(f"{file_root}/OSZICAR")
@@ -148,7 +148,7 @@ def parse_vasp_dir(
             dataset["stress"].append(ionic_step["stress"])
 
     if dataset["uncorrected_total_energy"] == []:
-        raise Exception(f"No data parsed from {file_root}!")
+        raise RuntimeError(f"No data parsed from {file_root}!")
 
     return dataset
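With the switch from a bare Exception to RuntimeError, callers can catch the empty-directory case narrowly instead of swallowing all errors. A sketch (the directory path is hypothetical):

    from chgnet.utils.vasp_utils import parse_vasp_dir

    try:
        dataset = parse_vasp_dir("./vasp_run")  # hypothetical VASP output directory
    except RuntimeError as exc:
        print(f"skipping empty calculation: {exc}")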
3 changes: 2 additions & 1 deletion examples/make_graphs.py
@@ -64,9 +64,10 @@ def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool:
     try:
         graph = data.graph_converter(struct, graph_id=graph_id, mp_id=mp_id)
         torch.save(graph, os.path.join(graph_dir, f"{graph_id}.pt"))
-        return dct
     except Exception:
         return False
+    else:
+        return dct
 
 
 def make_partition(
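Moving return dct into an else clause leaves behavior unchanged (it still runs only when graph conversion and torch.save both succeed) while keeping the return out of the guarded block, the shape linters like ruff's TRY300 (try-consider-else) recommend. A self-contained illustration of the pattern, not code from this repo:

    from __future__ import annotations

    def safe_parse(text: str) -> int | bool:
        try:
            value = int(text)  # the only fallible step
        except ValueError:
            return False
        else:
            return value  # reached only if the try block raised nothing

    assert safe_parse("42") == 42
    assert safe_parse("oops") is False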
13 changes: 12 additions & 1 deletion pyproject.toml
@@ -84,13 +84,20 @@ select = [
 ]
 ignore = [
     "ANN001", # TODO add missing type annotations
+    "ANN003",
     "ANN101", # Missing type annotation for self in method
+    "ANN102",
     "B019", # Use of functools.lru_cache on methods can lead to memory leaks
+    "BLE001",
     "C408", # unnecessary-collection-call
     "C901", # function is too complex
+    "COM812", # trailing comma missing
     "D100", # Missing docstring in public module
     "D104", # Missing docstring in public package
     "D205", # 1 blank line required between summary line and description
+    "DTZ005", # use of datetime.now() without timezone
+    "EM",
+    "ERA001", # found commented out code
     "FBT001", # Boolean positional argument in function
     "FBT002", # Boolean keyword argument in function
     "NPY002", # TODO replace legacy np.random.seed
@@ -101,15 +108,19 @@ ignore = [
     "PT013", # pytest-incorrect-pytest-import
     "PT019", # pytest-fixture-param-without-value
     "PTH", # prefer Path to os.path
+    "S301", # pickle can be unsafe
+    "S310",
+    "TRY003",
 ]
 pydocstyle.convention = "google"
 isort.required-imports = ["from __future__ import annotations"]
+isort.split-on-trailing-comma = false
 
 [tool.ruff.per-file-ignores]
 "site/*" = ["INP001", "S602"]
 "tests/*" = ["ANN201", "D103", "INP001", "S101"]
 # E402 Module level import not at top of file
-"examples/*" = ["E402", "I002", "T201"]
+"examples/*" = ["E402", "I002", "INP001", "N816", "S101", "T201"]
 "chgnet/**/*" = ["T201"]
 "__init__.py" = ["F401"]

11 changes: 4 additions & 7 deletions tests/test_md.py
@@ -280,10 +280,7 @@ def test_md_npt_nose_hoover(tmp_path: Path, monkeypatch: MonkeyPatch):
     assert_allclose(logs, ref, rtol=1e-2, atol=1e-7)
 
 
-def test_md_crystal_feas_log(
-    tmp_path: Path,
-    monkeypatch: MonkeyPatch,
-):
+def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
     monkeypatch.chdir(tmp_path)  # run MD in temporary directory
 
     md = MolecularDynamics(
@@ -293,13 +290,13 @@ def test_md_crystal_feas_log(
         timestep=2,  # in fs
         trajectory="md_out.traj",
         logfile="md_out.log",
-        crystal_feas_logfile="md_crystal_feas.p",
+        crystal_feas_logfile=(logfile := "md_crystal_feas.pkl"),
         loginterval=1,
     )
     md.run(100)
 
-    assert os.path.isfile("md_crystal_feas.p")
-    with open("md_crystal_feas.p", "rb") as file:
+    assert os.path.isfile(logfile)
+    with open(logfile, "rb") as file:
         data = pickle.load(file)
 
     crystal_feas = data["crystal_feas"]
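The walrus expression (logfile := "md_crystal_feas.pkl") binds the log file name at the call site, so the isfile assertion and the open below reuse the variable instead of repeating the string literal.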
6 changes: 3 additions & 3 deletions tests/test_model.py
@@ -200,9 +200,9 @@ def test_predict_batched_structures() -> None:
     out = model.predict_structure(structs, return_site_energies=True)
     assert len(out) == len(structs)
     for preds in out:
-        for property in ["e", "f", "s", "m", "site_energies"]:
-            assert preds[property] == pytest.approx(
-                pristine_prediction[property], rel=1e-3, abs=1e-3
+        for prop in ["e", "f", "s", "m", "site_energies"]:
+            assert preds[prop] == pytest.approx(
+                pristine_prediction[prop], rel=1e-3, abs=1e-3
             )


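Renaming the loop variable from property to prop stops it shadowing Python's built-in property descriptor, the kind of shadowing ruff's flake8-builtins rules (A001) flag.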
9 changes: 5 additions & 4 deletions tests/test_utils.py
@@ -13,11 +13,12 @@
 
 def test_cuda_devices_sorted_by_free_mem():
     # can't test this any better on CPU
-    # but good to check it doesn't crash on CPU
-    if torch.cuda.is_available() is False:
-        assert torch.cuda.device_count() == 0
+    # but good to check it at least doesn't crash on CPU
+    lst = cuda_devices_sorted_by_free_mem()
+    if torch.cuda.is_available():
+        assert len(lst) > 0
     else:
-        assert len(cuda_devices_sorted_by_free_mem()) > 0
+        assert lst == []
 
 
 @pytest.mark.parametrize("key", ["final_magmom", "magmom"])
