Skip to content

Commit

Permalink
compute_average_E0s_from_species
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jul 7, 2023
1 parent 431390e commit 042a21f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
2 changes: 2 additions & 0 deletions mace_jax/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Configuration,
Configurations,
compute_average_E0s,
compute_average_E0s_from_species,
config_from_atoms,
load_from_xyz,
random_train_valid_split,
Expand All @@ -23,6 +24,7 @@
"Configuration",
"Configurations",
"compute_average_E0s",
"compute_average_E0s_from_species",
"config_from_atoms",
"load_from_xyz",
"random_train_valid_split",
Expand Down
45 changes: 42 additions & 3 deletions mace_jax/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def load_from_xyz(

class AtomicNumberTable:
    """Table of atomic numbers, stored as a strictly increasing list of ints."""

    def __init__(self, zs: Sequence[int]):
        """Store ``zs`` as a list after checking it is a valid table.

        Args:
            zs: atomic numbers; must be Python ints, free of duplicates and
                already in ascending order.

        Raises:
            AssertionError: if any of the three conditions above is violated.
        """
        numbers = list(zs)
        # Only plain ints are accepted.
        assert not any(not isinstance(number, int) for number in numbers)
        # No atomic number may appear more than once.
        assert len(set(numbers)) == len(numbers)
        # The caller is expected to hand the numbers over already sorted.
        assert sorted(numbers) == numbers

        self.zs = numbers

def __len__(self) -> int:
Expand All @@ -205,9 +213,15 @@ def __str__(self):
def index_to_z(self, index: int) -> int:
    """Return the atomic number stored at position ``index`` of ``self.zs``."""
    atomic_numbers = self.zs
    return atomic_numbers[index]

def z_to_index(self, atomic_number: str) -> int:
def z_to_index(self, atomic_number: int) -> int:
    """Return the position of ``atomic_number`` within ``self.zs``.

    Raises:
        ValueError: if ``atomic_number`` is not present in the table.
    """
    position = self.zs.index(atomic_number)
    return position

def z_to_index_map(self, max_atomic_number: int) -> np.ndarray:
    """Build a lookup array that maps an atomic number to its table index.

    Args:
        max_atomic_number: largest atomic number the array can be indexed
            with; the result has ``max_atomic_number + 1`` entries.

    Returns:
        An ``int32`` array ``x`` where ``x[z]`` is the position of ``z`` in
        ``self.zs``. Entries for atomic numbers absent from the table stay
        at 0, which is also the index of the first table entry.
    """
    table_size = max_atomic_number + 1
    lookup = np.zeros(table_size, dtype=np.int32)
    position = 0
    for atomic_number in self.zs:
        lookup[atomic_number] = position
        position += 1
    return lookup


def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable:
    """Build an ``AtomicNumberTable`` from the distinct values of ``zs``, sorted ascending."""
    distinct_sorted = sorted(set(zs))
    return AtomicNumberTable(distinct_sorted)
Expand Down Expand Up @@ -250,20 +264,45 @@ def compute_average_E0s(
return atomic_energies_dict


def compute_average_E0s_from_species(
    graphs: "List[jraph.GraphsTuple]", num_species: int
) -> np.ndarray:
    """Fit one average energy per species index by linear least squares.

    Every graph contributes one equation: the sum over species ``j`` of
    (number of nodes whose species equals ``j``) times ``E0s[j]`` should
    match the graph's ``globals.energy``.

    Args:
        graphs: graphs exposing ``globals.energy`` and ``nodes.species``.
            Species values outside ``range(num_species)`` are not counted.
        num_species: number of species indices to fit.

    Returns:
        Array of length ``num_species`` holding the fitted energies (an
        array, not a dictionary). Falls back to all zeros, with a logged
        warning, if the least-squares solver raises ``LinAlgError``.
    """
    num_graphs = len(graphs)
    species_ids = np.arange(num_species)
    counts = np.zeros((num_graphs, num_species))
    energies = np.zeros(num_graphs)
    for row, graph in enumerate(graphs):
        energies[row] = graph.globals.energy
        # Compare every node against every species index in one broadcast,
        # then count matches per species (column).
        node_species = np.asarray(graph.nodes.species).reshape(-1, 1)
        counts[row] = np.count_nonzero(node_species == species_ids, axis=0)
    try:
        E0s = np.linalg.lstsq(counts, energies, rcond=None)[0]
    except np.linalg.LinAlgError:
        logging.warning(
            "Failed to compute E0s using least squares regression, using the same for all atoms"
        )
        E0s = np.zeros(num_species)
    return E0s


# Lightweight record types for the three parts of a graph: per-node fields,
# per-edge fields and per-graph (global) fields.
GraphNodes = namedtuple("Nodes", "positions forces species")
GraphEdges = namedtuple("Edges", "shifts")
GraphGlobals = namedtuple("Globals", "cell energy stress weight")


def graph_from_configuration(
config: Configuration, cutoff: float, z_map=None
config: Configuration, cutoff: float, z_table: AtomicNumberTable = None
) -> jraph.GraphsTuple:
senders, receivers, shifts = get_neighborhood(
positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell
)
if z_map is None:
if z_table is None:
species = config.atomic_numbers
else:
z_map = z_table.z_to_index_map(max_atomic_number=200)
species = z_map[config.atomic_numbers]

return jraph.GraphsTuple(
Expand Down
10 changes: 9 additions & 1 deletion mace_jax/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def evaluate(
loss_fn: Any,
data_loader: data.GraphDataLoader,
name: str = "Evaluation",
progress_bar: bool = True,
) -> Tuple[float, Dict[str, Any]]:
r"""Evaluate the predictor on the given data loader.
Expand Down Expand Up @@ -143,7 +144,14 @@ def evaluate(
last_cache_size = None

start_time = time.time()
p_bar = tqdm.tqdm(data_loader, desc=name, total=data_loader.approx_length())

p_bar = tqdm.tqdm(
data_loader,
desc=name,
total=data_loader.approx_length(),
disable=not progress_bar,
)

for ref_graph in p_bar:
output = predictor(params, ref_graph)
pred_graph = ref_graph._replace(
Expand Down

0 comments on commit 042a21f

Please sign in to comment.