diff --git a/l4casadi/realtime/sensitivities.py b/l4casadi/realtime/sensitivities.py index 36cb5e8..f6245ad 100644 --- a/l4casadi/realtime/sensitivities.py +++ b/l4casadi/realtime/sensitivities.py @@ -32,7 +32,9 @@ def batched_jacobian(func: Callable, inputs: torch.Tensor, create_graph=False, r if inputs.shape[0] == 1: vmap_randomness = 'same' else: - vmap_randomness = 'different' + # https://github.com/pytorch/functorch/issues/996 + # Should be 'different' + vmap_randomness = 'same' if not create_graph: with torch.no_grad(): @@ -65,7 +67,9 @@ def batched_hessian(func: Callable, inputs: torch.Tensor, create_graph=False, if inputs.shape[0] == 1: vmap_randomness = 'same' else: - vmap_randomness = 'different' + # https://github.com/pytorch/functorch/issues/996 + # Should be 'different' + vmap_randomness = 'same' def aux_function_jac(func): def inner_aux(inputs):