Skip to content

Commit

Permalink
Update test to capture the new desired behaviour.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminDev committed May 24, 2024
1 parent f327548 commit 3ab0b67
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/laplace/test_diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch):
) # validate args introduces control flows not yet supported in torch.func.vmap


def forward_m(params, batch, model):
y_pred = functional_call(model, params, batch[0])
def forward_m(params, b, model):
y_pred = functional_call(model, params, b[0])
return y_pred, torch.tensor([])


Expand Down

0 comments on commit 3ab0b67

Please sign in to comment.