-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
cuda_devices_sorted_by_free_mem()
return [] if `not torch.cuda.is_a…
…vailable()` (#115) * cuda_devices_sorted_by_free_mem() return [] if not torch.cuda.is_available() * refactor test_md_crystal_feas_log * change parse_vasp_dir() empty data exception type to RuntimeError * remove unused mkdir util * type annos
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ | |
AverageMeter, | ||
cuda_devices_sorted_by_free_mem, | ||
mae, | ||
mkdir, | ||
read_json, | ||
write_json, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
import os | ||
|
||
import nvidia_smi | ||
import torch | ||
|
@@ -13,6 +12,9 @@ def cuda_devices_sorted_by_free_mem() -> list[int]: | |
To get the device with the most free memory, use the last list item. | ||
""" | ||
if not torch.cuda.is_available(): | ||
return [] | ||
|
||
free_memories = [] | ||
nvidia_smi.nvmlInit() | ||
device_count = nvidia_smi.nvmlDeviceGetCount() | ||
|
@@ -63,45 +65,28 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor: | |
return torch.mean(torch.abs(target - prediction)) | ||
|
||
|
||
def read_json(fjson: str): | ||
def read_json(filepath: str) -> dict: | ||
"""Read the json file. | ||
Args: | ||
fjson (str): file name of json to read. | ||
filepath (str): file name of json to read. | ||
Returns: | ||
dictionary stored in fjson | ||
""" | ||
with open(fjson) as file: | ||
with open(filepath) as file: | ||
return json.load(file) | ||
|
||
|
||
def write_json(d: dict, fjson: str): | ||
def write_json(dct: dict, filepath: str) -> dict: | ||
"""Write the json file. | ||
Args: | ||
d (dict): dictionary to write | ||
fjson (str): file name of json to write. | ||
dct (dict): dictionary to write | ||
filepath (str): file name of json to write. | ||
Returns: | ||
written dictionary | ||
""" | ||
with open(fjson, "w") as file: | ||
json.dump(d, file) | ||
|
||
|
||
def mkdir(path: str): | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
janosh
Author
Collaborator
|
||
"""Make directory. | ||
Args: | ||
path (str): directory name | ||
Returns: | ||
path | ||
""" | ||
folder = os.path.exists(path) | ||
if not folder: | ||
os.makedirs(path) | ||
else: | ||
print("Folder exists") | ||
return path | ||
with open(filepath, "w") as file: | ||
json.dump(dct, file) |
@janosh let's not remove this function. this will cause so many import errors