Skip to content

Commit

Permalink
remove unnecessary float() for offering torch.float16 option (#220)
Browse files Browse the repository at this point in the history
* remove unnecessary float() for offering torch.float16 option

* unit test for default data type is added

* fix test_default_dtype.py
  • Loading branch information
kenko911 authored Feb 2, 2024
1 parent 096f512 commit 81ff43e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/matgl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ def set_default_dtype(type_: str = "float", size: int = 32):
torch.set_default_dtype(getattr(torch, f"float{size}"))
else:
raise ValueError("Invalid dtype size")
if type_ == "float" and size == 16 and not torch.cuda.is_available():
raise Exception(
"torch.float16 is not supported for M3GNet because addmm_impl_cpu_ is not implemented"
" for this floating precision, please use size = 32, 64 or using 'cuda' instead !!"
)
6 changes: 5 additions & 1 deletion src/matgl/apps/pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def forward(
hessian[iatom] = tmp.view(-1)

if self.calc_stresses:
volume = torch.abs(torch.det(lattice))
volume = (
torch.abs(torch.det(lattice.float())).half()
if matgl.float_th == torch.float16
else torch.abs(torch.det(lattice))
)
sts = -grads[1]
scale = 1.0 / volume * -160.21766208
sts = [i * j for i, j in zip(sts, scale)] if sts.dim() == 3 else [sts * scale]
Expand Down
2 changes: 1 addition & 1 deletion src/matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def compute_pair_vector_and_distance(g: dgl.DGLGraph):
"""
dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["pbc_offshift"]
src_pos = g.ndata["pos"][g.edges()[0]]
bond_vec = (dst_pos - src_pos).float()
bond_vec = dst_pos - src_pos
bond_dist = torch.norm(bond_vec, dim=1)

return bond_vec, bond_dist
Expand Down
2 changes: 0 additions & 2 deletions src/matgl/layers/_graph_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def _edge_udf(self, edges: dgl.udf.EdgeBatch):
u = edges.src["u"]
eij = edges.data.pop("e")
rbf = edges.data["rbf"]
rbf = rbf.float()
inputs = torch.hstack([vi, vj, eij, u]) if self.include_states else torch.hstack([vi, vj, eij])
mij = {"mij": self.edge_update_func(inputs) * self.edge_weight_func(rbf)}
return mij
Expand Down Expand Up @@ -339,7 +338,6 @@ def node_update_(self, graph: dgl.DGLGraph, state_feat: Tensor) -> Tensor:
dst_id = graph.edges()[1]
vj = graph.ndata["v"][dst_id]
rbf = graph.edata["rbf"]
rbf = rbf.float()
if self.include_states:
u = dgl.broadcast_edges(graph, state_feat)
inputs = torch.hstack([vi, vj, eij, u])
Expand Down
24 changes: 24 additions & 0 deletions tests/test_default_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import numpy as np
import pytest
import torch
from matgl import set_default_dtype


def test_set_default_dtype():
    """Requesting ("float", 32) must make torch's global default dtype float32."""
    set_default_dtype("float", 32)
    assert torch.get_default_dtype() == torch.float32
    # NOTE(review): the original test also asserted
    # `np.dtype("float32") == np.float32`, which is a tautology — it is True
    # regardless of anything set_default_dtype did, so it verified nothing and
    # was removed. If matgl exposes a numpy-side default (e.g. a float_np
    # attribute), assert on that here instead — TODO confirm against matgl's API.


def test_set_default_dtype_invalid_size():
    """An unsupported size (128) must raise ValueError; then restore float32."""
    with pytest.raises(ValueError) as excinfo:
        set_default_dtype("float", 128)
    # Same check as match="Invalid dtype size": the message must mention it.
    assert "Invalid dtype size" in str(excinfo.value)
    # Reset the global default so later tests see the usual float32 state.
    set_default_dtype("float", 32)


def test_set_default_dtype_exception():
    """float16 on a CPU-only box must raise; then restore float32."""
    with pytest.raises(Exception) as excinfo:
        set_default_dtype("float", 16)
    # Same check as match="torch.float16 is not supported for M3GNet".
    assert "torch.float16 is not supported for M3GNet" in str(excinfo.value)
    # Reset the global default so later tests see the usual float32 state.
    set_default_dtype("float", 32)

0 comments on commit 81ff43e

Please sign in to comment.