Skip to content

Commit

Permalink
add test to cover exception condition
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro committed Oct 18, 2024
1 parent 5051025 commit d523ff3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,17 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
)
# Handle scalar parameters
size = max(1, prod(param.shape))
if batch:
try:
try:
if batch:
param.value = params[..., pos : pos + size].view(tuple(B) + param.shape)
except IndexError:
raise AssertionError(
f"Batched input params shape {params.shape} does not match dynamic params shape. Make sure the last dimension has size equal to the sum of all dynamic params sizes."
)
pos += size
else:
param.value = params[pos : pos + size].view(param.shape)
pos += size
else:
param.value = params[pos : pos + size].view(param.shape)
except (RuntimeError, IndexError):
fullnumel = sum(max(1, prod(p.shape)) for p in self.dynamic_params)
raise AssertionError(
f"Input params shape {params.shape} does not match dynamic params shape. Make sure the last dimension has size equal to the sum of all dynamic params sizes ({fullnumel})."
)
pos += size
assert (
pos == params.shape[-1]
), f"Input params length {params.shape} does not match dynamic params length. Not all dynamic params were filled."
Expand Down
6 changes: 6 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def __call__(self, d=None, e=None, live_c=None):
assert result.shape == (2, 2)
result = main1.testfun(1.0, params)
assert result.shape == (2, 2)
# Wrong number of params, too few
with pytest.raises(AssertionError):
result = main1.testfun(1.0, params[:-3])
# Wrong number of params, too many
with pytest.raises(AssertionError):
result = main1.testfun(1.0, torch.cat((params, params)))

# Batched tensor as params
params = params.repeat(3, 1).unsqueeze(1)
Expand Down

0 comments on commit d523ff3

Please sign in to comment.