diff --git a/tests/laplace/test_diag_ggn.py b/tests/laplace/test_diag_ggn.py index 5961e1c8..89a172b9 100644 --- a/tests/laplace/test_diag_ggn.py +++ b/tests/laplace/test_diag_ggn.py @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch): ) # validate args introduces control flows not yet supported in torch.func.vmap -def forward_m(params, batch, model): - y_pred = functional_call(model, params, batch[0]) +def forward_m(params, b, model): + y_pred = functional_call(model, params, b[0]) return y_pred, torch.tensor([])