Skip to content

Commit

Permalink
fix recent changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 23, 2023
1 parent 273a22b commit 431390e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/aspirin.gin
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ datasets.valid_fraction = 0.208
datasets.n_node = 512
datasets.n_edge = 2048
datasets.n_graph = 6
datasets.n_mantissa_bits = 3
datasets.n_mantissa_bits = 2



Expand Down
2 changes: 1 addition & 1 deletion configs/aspirin_small.gin
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ datasets.valid_fraction = 0.208
datasets.n_node = 512
datasets.n_edge = 2048
datasets.n_graph = 6
datasets.n_mantissa_bits = 3
datasets.n_mantissa_bits = 2



Expand Down
6 changes: 3 additions & 3 deletions mace_jax/tools/gin_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def weight_decay_mask(params):

@gin.configurable
def train(
model,
predictor,
params,
optimizer_state,
train_loader,
Expand Down Expand Up @@ -212,7 +212,7 @@ def train(

for interval, params, optimizer_state, ema_params in tools.train(
params=params,
loss_fn=lambda params, graph: loss_fn(graph, model(params, graph)),
total_loss_fn=lambda params, graph: loss_fn(graph, predictor(params, graph)),
train_loader=train_loader,
gradient_transform=gradient_transform,
optimizer_state=optimizer_state,
Expand All @@ -237,7 +237,7 @@ def train(

def eval_and_print(loader, mode: str):
loss_, metrics_ = tools.evaluate(
model=model,
predictor=predictor,
params=ema_params,
loss_fn=loss_fn,
data_loader=loader,
Expand Down

0 comments on commit 431390e

Please sign in to comment.