From 3ab0b67305e2295ed5bac2740b8dd9fa9e04732a Mon Sep 17 00:00:00 2001 From: benjaminDev Date: Fri, 24 May 2024 12:01:46 +0200 Subject: [PATCH] Update test to capture the new desired behaviour. --- tests/laplace/test_diag_ggn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/laplace/test_diag_ggn.py b/tests/laplace/test_diag_ggn.py index 5961e1c8..89a172b9 100644 --- a/tests/laplace/test_diag_ggn.py +++ b/tests/laplace/test_diag_ggn.py @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch): ) # validate args introduces control flows not yet supported in torch.func.vmap -def forward_m(params, batch, model): - y_pred = functional_call(model, params, batch[0]) +def forward_m(params, b, model): + y_pred = functional_call(model, params, b[0]) return y_pred, torch.tensor([])