Skip to content

Commit

Permalink
add spice to evaluation routine
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Jan 24, 2024
1 parent 6145598 commit f366b8e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 10 additions & 3 deletions mlff/config/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,22 @@ def run_evaluation(config, num_test: int = None, testing_targets: Sequence[str]

if data_filepath.suffix == '.npz':
loader = data.NpzDataLoaderSparse(input_file=data_filepath)
elif data_filepath.stem[:5].lower() == 'spice':
logging.mlff(f'Found SPICE dataset at {data_filepath}.')
if data_filepath.suffix != '.hdf5':
raise ValueError(
f'Loader assumes that SPICE is in hdf5 format. Found {data_filepath.suffix} as'
f'suffix.')
loader = data.SpiceDataLoaderSparse(input_file=data_filepath)
else:
loader = data.AseDataLoaderSparse(input_file=data_filepath)

all_data, data_stats = loader.load_all(cutoff=config.model.cutoff)
num_data = len(all_data)

energy_unit = eval(config.data.energy_unit)
length_unit = eval(config.data.length_unit)

all_data, data_stats = loader.load_all(cutoff=config.model.cutoff / length_unit)
num_data = len(all_data)

split_seed = config.data.split_seed
numpy_rng = np.random.RandomState(split_seed)
numpy_rng.shuffle(all_data)
Expand Down
3 changes: 2 additions & 1 deletion mlff/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def fit(
batch_training = graph_to_batch_fn(graph_batch_training)
processed_graphs += batch_training['num_of_non_padded_graphs']
processed_nodes += batch_max_num_nodes - jraph.get_number_of_padding_with_graphs_nodes(graph_batch_training)
# Training data is numpy arrays so we now transform them to jax.numpy arrays.
batch_training = jax.tree_map(jnp.array, batch_training)

# In the first step, initialize the parameters or load from existing checkpoint.
Expand Down Expand Up @@ -287,7 +288,7 @@ def fit(
eval_metrics = {
f'eval_{k}': float(v) for k, v in eval_metrics.items()
}

print(eval_metrics)
# Save checkpoint.
ckpt_mngr.save(
step,
Expand Down

0 comments on commit f366b8e

Please sign in to comment.