Skip to content

Commit

Permalink
safer context enter
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro committed Oct 17, 2024
1 parent 0781b46 commit 797eb54
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
6 changes: 1 addition & 5 deletions src/caskade/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@


class ActiveContext:
def __init__(
self, module: Module, params: Union[Sequence[Tensor], Mapping[str, Tensor], Tensor]
):
def __init__(self, module: Module):
self.module = module
self.params = params

def __enter__(self):
self.module.active = True
self.module.fill_params(self.params)

def __exit__(self, exc_type, exc_value, traceback):
self.module.clear_params()
Expand Down
3 changes: 2 additions & 1 deletion src/caskade/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def wrapped(self, *args, **kwargs):
f"Params must be provided for dynamic modules. Expected {len(self.dynamic_params)} params."
)

with ActiveContext(self, params):
with ActiveContext(self):
self.fill_params(params)
kwargs.update(self.fill_kwargs(method_kwargs))
return method(self, *args, **kwargs)

Expand Down
5 changes: 2 additions & 3 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,11 @@ def fill_params(self, params):
*B, _ = params.shape
pos = 0
for param in self.dynamic_params:
try:
size = max(1, prod(param.shape))
except TypeError:
if not isinstance(param.shape, tuple):
raise ValueError(
f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input."
)
size = max(1, prod(param.shape))
if self.batch:
param.value = params[..., pos : pos + size].view(
tuple(B) + ((1,) if param.shape == () else param.shape)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def __call__(self, d=None, e=None):
main1 = TestSim(2.0, (2, 2), LiveParam, (2,), sub1)

# Dont provide params
print(main1.active)
with pytest.raises(ValueError):
main1.testfun()
print(main1.active)

# List as params
params = [torch.ones((2, 2)), torch.tensor(3.0), torch.tensor(4.0), torch.tensor(1.0)]
Expand Down Expand Up @@ -100,8 +102,10 @@ def __call__(self, d=None, e=None):
# dynamic with no shape
main1.b = None
main1.b.shape = None
print(main1.active)
with pytest.raises(ValueError):
main1.testfun(1.0, params=torch.ones(4))
print(main1.active)
result = main1.testfun(1.0, params=[torch.ones((2, 2))])
assert result.shape == (2, 2)

Expand Down

0 comments on commit 797eb54

Please sign in to comment.