Commit
metrics can be written to data file in evaluate_so3krates_sparse_on
thorben-frank committed Jan 30, 2024
1 parent 5690a82 commit 8f4d8c8
Showing 5 changed files with 70 additions and 204 deletions.
43 changes: 41 additions & 2 deletions mlff/CLI/run_evaluation_on.py
@@ -3,6 +3,13 @@
from mlff.config import from_config
from ml_collections import config_dict
import pathlib
import logging
from functools import partial, partialmethod

logging.MLFF = 35
logging.addLevelName(logging.MLFF, 'MLFF')
logging.Logger.trace = partialmethod(logging.Logger.log, logging.MLFF)
logging.mlff = partial(logging.log, logging.MLFF)


def evaluate_so3krates_sparse_on():
@@ -58,9 +65,26 @@ def evaluate_so3krates_sparse_on():
default=None,
help='Number of test points to use. If given, the first num_test points in the dataset are used for evaluation.'
)

parser.add_argument('--write_batch_metrics_to',
type=str,
required=False,
default=None,
help='Path to the CSV file where the per-batch metrics should be written. If not given, '
'batch metrics are not written to a file. Note that the metrics are written per batch, '
'so a one-to-one correspondence to the original data set can only be achieved with '
'`batch_max_num_graphs = 2`, which allows one graph per batch, following the `jraph` '
'logic that one graph is used as the padding graph.'
)
args = parser.parse_args()

if args.num_test is not None and args.write_batch_metrics_to is not None:
raise ValueError(
f'--num_test={args.num_test} is not `None`, so data is randomly sub-sampled from {args.filepath}. '
f'At the same time, `--write_batch_metrics_to={args.write_batch_metrics_to}` is specified. Due to the '
f'random subsampling of the data, there is no one-to-one correspondence between the lines in the '
f'metrics file and the indices of the data points, so an error is raised here to be safe.'
)

workdir = pathlib.Path(args.workdir).expanduser().resolve()

with open(workdir / 'hyperparameters.json', 'r') as fp:
@@ -77,7 +101,22 @@ def evaluate_so3krates_sparse_on():
cfg.training.batch_max_num_edges = args.max_num_edges
cfg.training.batch_max_num_nodes = args.max_num_nodes

metrics = from_config.run_evaluation(config=cfg, num_test=args.num_test, pick_idx=None)
# Expand and resolve path for writing metrics.
write_batch_metrics_to = pathlib.Path(
args.write_batch_metrics_to
).expanduser().resolve() if args.write_batch_metrics_to is not None else None

# Append the `.csv` suffix if it is missing.
if write_batch_metrics_to is not None and write_batch_metrics_to.suffix != '.csv':
    write_batch_metrics_to = pathlib.Path(f'{write_batch_metrics_to}.csv')

metrics = from_config.run_evaluation(
config=cfg,
num_test=args.num_test,
pick_idx=None,
write_batch_metrics_to=write_batch_metrics_to
)
print(metrics)


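For reference, a minimal sketch of the path handling introduced above, using only the standard library; `normalize_metrics_path` is a hypothetical helper written for illustration, not part of the commit:

```python
import pathlib
from typing import Optional


def normalize_metrics_path(raw: Optional[str]) -> Optional[pathlib.Path]:
    # Mirrors the CLI logic: expand the user path, resolve it, and enforce a `.csv` suffix.
    if raw is None:
        return None
    path = pathlib.Path(raw).expanduser().resolve()
    if path.suffix != '.csv':
        path = pathlib.Path(f'{path}.csv')
    return path


print(normalize_metrics_path('~/results/batch_metrics'))  # .../results/batch_metrics.csv
print(normalize_metrics_path(None))                        # None
```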
8 changes: 7 additions & 1 deletion mlff/config/from_config.py
@@ -263,7 +263,8 @@ def run_evaluation(
config,
num_test: int = None,
testing_targets: Sequence[str] = None,
pick_idx: np.ndarray = None
pick_idx: np.ndarray = None,
write_batch_metrics_to: str = None
):
"""Run evaluation, given the config and additional args.
@@ -275,6 +276,10 @@
testing_targets (): Targets used for computing metrics. Defaults to the ones found in
config.training.loss_weights.
pick_idx (): Indices to evaluate the model on. Loads only the data at the given indices.
write_batch_metrics_to (str): Path to the file where the per-batch metrics should be written. If not given,
batch metrics are not written to a file. Note that the metrics are written per batch, so a one-to-one
correspondence to the original data set can only be achieved with `batch_max_num_graphs = 2`, which allows
one graph per batch, following the `jraph` logic that one graph is used as the padding graph.
Returns:
The metrics on `testing_targets`.
@@ -359,4 +364,5 @@
batch_max_num_nodes=config.training.batch_max_num_nodes,
batch_max_num_edges=config.training.batch_max_num_edges,
batch_max_num_graphs=config.training.batch_max_num_graphs,
write_batch_metrics_to=write_batch_metrics_to
)
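A hedged sketch of calling the extended `run_evaluation` directly from Python; the working directory, the `hyperparameters.json` loading, and the output path are assumptions that mirror the CLI script above rather than code from this commit:

```python
import json
import pathlib

from ml_collections import config_dict
from mlff.config import from_config

# Assumed working directory of a finished so3krates training run.
workdir = pathlib.Path('~/so3krates_run').expanduser().resolve()

with open(workdir / 'hyperparameters.json', 'r') as fp:
    cfg = config_dict.ConfigDict(json.load(fp))

metrics = from_config.run_evaluation(
    config=cfg,
    num_test=None,  # keep None so rows in the CSV can map onto data points
    pick_idx=None,
    write_batch_metrics_to=str(workdir / 'batch_metrics.csv'),
)
print(metrics)
```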
192 changes: 0 additions & 192 deletions mlff/training/run_sparse.py

This file was deleted.

30 changes: 21 additions & 9 deletions mlff/utils/evaluation_utils.py
@@ -4,6 +4,7 @@
import jax.numpy as jnp
import jraph
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import Any
from mlff.nn.stacknet.observable_function_sparse import get_energy_and_force_fn_sparse
@@ -17,7 +18,8 @@ def evaluate(
testing_targets,
batch_max_num_nodes,
batch_max_num_edges,
batch_max_num_graphs
batch_max_num_graphs,
write_batch_metrics_to: str = None
):
"""Evaluate a model given its params on the testing data.
@@ -31,6 +33,10 @@
batch_max_num_nodes (): Maximal number of nodes per batch.
batch_max_num_edges (): Maximal number of edges per batch.
batch_max_num_graphs (): Maximal number of graphs per batch.
write_batch_metrics_to (str): Path to the file where the per-batch metrics should be written. If not given,
batch metrics are not written to a file. Note that the metrics are written per batch, so a one-to-one
correspondence to the original data set can only be achieved with `batch_max_num_graphs = 2`, which allows
one graph per batch, following the `jraph` logic that one graph is used as the padding graph.
Returns:
The metrics on testing data.
@@ -52,7 +58,7 @@
**{f'{t}_{m}': clu_metrics.Average.from_output(f'{t}_{m}') for (t, m) in it.product(testing_targets, ('mae', 'mse'))})

# Start iteration over validation batches.
testing_metrics = []
row_metrics = []
test_metrics: Any = None
for graph_batch_testing in tqdm(iterator_testing):
batch_testing = graph_to_batch_fn(graph_batch_testing)
@@ -73,14 +79,20 @@
elif t == 'stress':
msk = graph_mask
else:
raise ValueError(f"Evaluate not implemented for target={t}.")
raise ValueError(
f"Evaluate not implemented for target={t}."
)

metrics_dict[f"{t}_mae"] = calculate_mae(
y_predicted=output_prediction[t], y_true=batch_testing[t], msk=msk
),
)
metrics_dict[f"{t}_mse"] = calculate_mse(
y_predicted=output_prediction[t], y_true=batch_testing[t], msk=msk
),
)

# Collect the per-batch metrics if they are to be written to a file.
if write_batch_metrics_to is not None:
row_metrics += [jax.device_get(metrics_dict)]

test_metrics = (
test_collection.single_from_model_output(**metrics_dict)
@@ -89,10 +101,10 @@
)
test_metrics = test_metrics.compute()

# testing_metrics_np = jax.device_get(testing_metrics)
# testing_metrics_np = {
# k: np.mean([m[k] for m in testing_metrics_np]) for k in testing_metrics_np[0]
# }
if write_batch_metrics_to:
df = pd.DataFrame(row_metrics)
with open(write_batch_metrics_to, mode='w') as fp:
df.to_csv(fp)

test_metrics = {
f'test_{k}': float(v) for k, v in test_metrics.items()
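Once evaluation has finished, the per-batch CSV can be inspected with pandas; a minimal sketch, assuming the file name used above and that `energy` and `forces` were among the testing targets (the column names follow the `{target}_{metric}` keys of `metrics_dict`):

```python
import pandas as pd

# One row per evaluation batch, written in iteration order; the first CSV column is the row index.
df = pd.read_csv('batch_metrics.csv', index_col=0)

# With batch_max_num_graphs = 2 (one real graph plus the jraph padding graph),
# row i corresponds to the i-th evaluated structure.
print(df[['energy_mae', 'forces_mae']].describe())
```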
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
"optax",
"orbax-checkpoint",
"portpicker",
"pandas",
# 'tensorflow',
"scikit-learn",
"ase",
