diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py
index 0a74570..50a6ee6 100644
--- a/FrEIA/modules/all_in_one_block.py
+++ b/FrEIA/modules/all_in_one_block.py
@@ -239,6 +239,9 @@ def _affine(self, x, a, rev=False):
 
     def forward(self, x, c=[], rev=False, jac=True):
         '''See base class docstring'''
+        if x.shape[1:] != self.dims_in[0][1:]:
+            raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, "
+                               f"got {x.shape}.")
         if self.householder:
             self.w_perm = self._construct_householder_permutation()
             if rev or self.reverse_pre_permute:
diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py
index 44d46fb..ad20009 100644
--- a/FrEIA/modules/splines/binned.py
+++ b/FrEIA/modules/splines/binned.py
@@ -94,19 +94,14 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
         assert default_domain[3] - default_domain[2] >= min_bin_sizes[1] * bins, \
             "{bins} bins of size {min_bin_sizes[1]} are too large for domain {default_domain[2]} to {default_domain[3]}"
 
-        if domain_clamping is not None:
-            self.clamp_domain = lambda domain: domain_clamping * torch.tanh(
-                domain / domain_clamping
-            )
-        else:
-            self.clamp_domain = lambda domain: domain
-
         self.register_buffer("bins", torch.tensor(bins, dtype=torch.int32))
         self.register_buffer("min_bin_sizes", torch.as_tensor(min_bin_sizes, dtype=torch.float32))
         self.register_buffer("default_domain", torch.as_tensor(default_domain, dtype=torch.float32))
         self.register_buffer("identity_tails", torch.tensor(identity_tails, dtype=torch.bool))
         self.register_buffer("default_width", torch.as_tensor(default_domain[1] - default_domain[0], dtype=torch.float32))
 
+        self.domain_clamping = domain_clamping
+
         # The default parameters are
         #                                  parameter        constraints        count
         # 1. the leftmost bin edge                                             - 1
@@ -140,6 +135,15 @@ def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str
 
         return dict(zip(keys, values))
 
+    def clamp_domain(self, domain: torch.Tensor) -> torch.Tensor:
+        """
+        Softly clamp the domain size to (-domain_clamping, domain_clamping) via tanh.
+        """
+        if self.domain_clamping is None:
+            return domain
+        else:
+            return self.domain_clamping * torch.tanh(domain / self.domain_clamping)
+
     def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Constrain Parameters to meet certain conditions (e.g. positivity)
diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py
index 51b5fea..e8673d0 100644
--- a/FrEIA/modules/splines/rational_quadratic.py
+++ b/FrEIA/modules/splines/rational_quadratic.py
@@ -164,7 +164,8 @@ def rational_quadratic_spline(x: torch.Tensor,
 
     # Eq 29 in the appendix of the paper
     discriminant = b ** 2 - 4 * a * c
-    assert torch.all(discriminant >= 0), f"Discriminant must be positive, but is violated by {torch.min(discriminant)}"
+    if not torch.all(discriminant >= 0):
+        raise RuntimeError(f"Discriminant must be non-negative, but the minimum is {torch.min(discriminant)}")
 
     xi = 2 * c / (-b - torch.sqrt(discriminant))
 
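
Note on the binned.py change: turning the `clamp_domain` lambdas into a bound method keeps the behavior identical while removing lambda attributes from the module, likely so the full module remains picklable (lambdas stored as attributes break `torch.save` of the whole module and `copy.deepcopy`). A minimal standalone sketch of the soft clamp, with hypothetical values that are not part of this PR:

```python
from typing import Optional

import torch

def clamp_domain(domain: torch.Tensor, domain_clamping: Optional[float]) -> torch.Tensor:
    """Standalone illustration of the new clamp_domain method."""
    if domain_clamping is None:
        return domain  # clamping disabled: pass the domain size through unchanged
    # Soft clamp: smooth, monotone, bounded by (-domain_clamping, domain_clamping),
    # and approximately the identity for |domain| much smaller than domain_clamping.
    return domain_clamping * torch.tanh(domain / domain_clamping)

print(clamp_domain(torch.tensor([0.5, 100.0]), domain_clamping=5.0))
# -> tensor([0.4983, 5.0000]) approximately: small sizes pass through, large ones saturate
```

Because the clamp is `tanh`-based rather than a hard `torch.clamp`, it stays differentiable everywhere, so gradients continue to flow through a predicted domain size even when it saturates.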
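Note on the rational_quadratic.py change: replacing the `assert` with an explicit `RuntimeError` means the check also survives `python -O`, which strips assert statements. A toy sketch of the guarded root computation, with made-up coefficients standing in for the spline's per-element a, b, c (the in-code comment points at Eq 29 of the referenced paper, presumably Durkan et al.'s "Neural Spline Flows"):

```python
import torch

# Made-up quadratic coefficients; in the spline these are derived from the
# bin widths, bin heights, and knot derivatives.
a = torch.tensor([1.0, 0.5])
b = torch.tensor([-3.0, -2.0])
c = torch.tensor([2.0, 1.5])

discriminant = b ** 2 - 4 * a * c
if not torch.all(discriminant >= 0):
    raise RuntimeError(f"Discriminant must be non-negative, but the minimum is {torch.min(discriminant)}")

# The same root form the spline code uses; the paper picks this
# rearrangement of the quadratic formula for numerical stability.
xi = 2 * c / (-b - torch.sqrt(discriminant))
print(xi)  # -> tensor([2., 3.])
```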