Skip to content

Commit

Permalink
safety
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Feb 3, 2023
1 parent 32cde59 commit 0072f18
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
5 changes: 4 additions & 1 deletion mace_jax/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def main():
train_loader, valid_loader, test_loader, atomic_energies_dict, r_max = datasets()

model_fn, params, num_message_passing = model(
r_max, atomic_energies_dict, train_loader.graphs, initialize_seed=seed
r_max=r_max,
atomic_energies_dict=atomic_energies_dict,
train_graphs=train_loader.graphs,
initialize_seed=seed,
)

params = reload(params)
Expand Down
13 changes: 7 additions & 6 deletions mace_jax/tools/gin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def __call__(self, node_specie: jnp.ndarray) -> e3nn.IrrepsArray:

@gin.configurable
def model(
*,
r_max: float,
atomic_energies_dict: Dict[int, float] = None,
train_graphs: List[jraph.GraphsTuple] = None,
initialize_seed: Optional[int] = None,
*,
scaling: Callable = None,
atomic_energies: Union[str, np.ndarray, Dict[int, float]] = None,
avg_num_neighbors: float = "average",
Expand Down Expand Up @@ -166,11 +166,12 @@ def model(

# check that num_species is consistent with the dataset
if z_table is None:
for graph in train_graphs:
if not np.all(graph.nodes.species < num_species):
raise ValueError(
f"max(graph.nodes.species)={np.max(graph.nodes.species)} >= num_species={num_species}"
)
if train_graphs is not None:
for graph in train_graphs:
if not np.all(graph.nodes.species < num_species):
raise ValueError(
f"max(graph.nodes.species)={np.max(graph.nodes.species)} >= num_species={num_species}"
)
else:
if max(z_table.zs) >= num_species:
raise ValueError(
Expand Down

0 comments on commit 0072f18

Please sign in to comment.