From d523ff3491c2537a5cfd7847e991fa9c27731d2f Mon Sep 17 00:00:00 2001
From: Connor Stone
Date: Thu, 17 Oct 2024 23:05:19 -0400
Subject: [PATCH] add test to cover exception condition

---
 src/caskade/module.py | 20 ++++++++++----------
 tests/test_forward.py |  6 ++++++
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/caskade/module.py b/src/caskade/module.py
index cfa5caf..3ae3c3b 100644
--- a/src/caskade/module.py
+++ b/src/caskade/module.py
@@ -101,17 +101,17 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
                     )
                 # Handle scalar parameters
                 size = max(1, prod(param.shape))
-                if batch:
-                    try:
+                try:
+                    if batch:
                         param.value = params[..., pos : pos + size].view(tuple(B) + param.shape)
-                    except IndexError:
-                        raise AssertionError(
-                            f"Batched input params shape {params.shape} does not match dynamic params shape. Make sure the last dimension has size equal to the sum of all dynamic params sizes."
-                        )
-                    pos += size
-                else:
-                    param.value = params[pos : pos + size].view(param.shape)
-                    pos += size
+                    else:
+                        param.value = params[pos : pos + size].view(param.shape)
+                except (RuntimeError, IndexError):
+                    fullnumel = sum(max(1, prod(p.shape)) for p in self.dynamic_params)
+                    raise AssertionError(
+                        f"Input params shape {params.shape} does not match dynamic params shape. Make sure the last dimension has size equal to the sum of all dynamic params sizes ({fullnumel})."
+                    )
+                pos += size
             assert (
                 pos == params.shape[-1]
             ), f"Input params length {params.shape} does not match dynamic params length. Not all dynamic params were filled."
diff --git a/tests/test_forward.py b/tests/test_forward.py
index eb32e0d..041581b 100644
--- a/tests/test_forward.py
+++ b/tests/test_forward.py
@@ -52,6 +52,12 @@ def __call__(self, d=None, e=None, live_c=None):
     assert result.shape == (2, 2)
     result = main1.testfun(1.0, params)
     assert result.shape == (2, 2)
+    # Wrong number of params, too few
+    with pytest.raises(AssertionError):
+        result = main1.testfun(1.0, params[:-3])
+    # Wrong number of params, too many
+    with pytest.raises(AssertionError):
+        result = main1.testfun(1.0, torch.cat((params, params)))
     # Batched tensor as params
     params = params.repeat(3, 1).unsqueeze(1)
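
For reference, a minimal sketch of the failure mode the new test covers, assuming the Module/Param/forward API exposed by caskade; the MySim module and its parameters below are hypothetical and not part of this patch:

import torch
from caskade import Module, Param, forward

class MySim(Module):
    def __init__(self):
        super().__init__()
        self.a = Param("a", None)  # dynamic scalar param (no value given)
        self.b = Param("b", None)  # dynamic scalar param (no value given)

    @forward
    def testfun(self, x, a=None, b=None):
        # a and b are filled from the flattened params tensor at call time
        return x * a + b

sim = MySim()
sim.testfun(1.0, torch.tensor([2.0, 3.0]))           # two dynamic params -> OK
# sim.testfun(1.0, torch.tensor([2.0]))              # too few  -> AssertionError
# sim.testfun(1.0, torch.tensor([2.0, 3.0, 4.0]))    # too many -> AssertionError

With the patched fill_params, both mismatched calls raise an AssertionError whose message reports the expected total size of the dynamic params, instead of only covering the batched IndexError path.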