Skip to content

Commit

Permalink
docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed May 12, 2023
1 parent 59b30f0 commit 5edaff9
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mace_jax/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def evaluate(
data_loader: data.GraphDataLoader,
name: str = "Evaluation",
) -> Tuple[float, Dict[str, Any]]:
r"""Evaluate the model on the given data loader.
Args:
model: function of signature `model(params, graph) -> {energy: [num_graphs], forces: [num_nodes, 3], stress: [num_graphs, 3, 3]}`
params: parameters of the model
loss_fn: function of signature `loss_fn(graph, output) -> loss` where `output` is the output of `model`
data_loader: data loader
"""
total_loss = 0.0
num_graphs = 0

Expand Down

0 comments on commit 5edaff9

Please sign in to comment.