Skip to content

Commit

Permalink
Add keyword use_device to CHGNet.load()
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziyang HU committed Dec 15, 2023
1 parent 3165efe commit f74dc29
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GraphAttentionReadOut,
GraphPooling,
)
from chgnet.utils import cuda_devices_sorted_by_free_mem

if TYPE_CHECKING:
from chgnet import PredTask
Expand Down Expand Up @@ -668,11 +669,15 @@ def from_file(cls, path, **kwargs):
return CHGNet.from_dict(state["model"], **kwargs)

@classmethod
def load(cls, model_name="0.3.0"):
def load(cls, model_name="0.3.0", use_device: str | None = None):
"""Load pretrained CHGNet model.
Args:
model_name (str, optional): Defaults to "0.3.0".
use_device (str, optional): The device to be used for predictions,
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
Raises:
ValueError: On unknown model_name.
Expand All @@ -685,7 +690,7 @@ def load(cls, model_name="0.3.0"):
if checkpoint_path is None:
raise ValueError(f"Unknown {model_name=}")

return cls.from_file(
model = cls.from_file(
os.path.join(module_dir, checkpoint_path),
# mlp_out_bias=True is set for backward compatible behavior but in rare
# cases causes unphysical jumps in bonding energy. see
Expand All @@ -694,6 +699,19 @@ def load(cls, model_name="0.3.0"):
version=model_name,
)

# Determine the device to use
if use_device == "mps" and torch.backends.mps.is_available():
device = "mps"
else:
device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"

# Move the model to the specified device
model = model.to(device)
print(f"CHGNet will run on {device}")
return model


@dataclass
class BatchedGraph:
Expand Down

0 comments on commit f74dc29

Please sign in to comment.