diff --git a/FrEIA/modules/coupling_layers.py b/FrEIA/modules/coupling_layers.py
index 5757856..8a4e3ba 100644
--- a/FrEIA/modules/coupling_layers.py
+++ b/FrEIA/modules/coupling_layers.py
@@ -462,6 +462,80 @@ def forward(self, x, c=[], rev=False, jac=True):
         return (torch.cat((x1, y2), 1),), j
 
 
+class AffineCouplingLopSided(_BaseCouplingBlock):
+    '''A lop-sided coupling block derived from the GLOWCouplingBlock
+    design, performing an affine transformation on one half of the inputs
+    and an additive transformation on the other. As used in IRN
+    (Xiao et al., 2020), the additive half can effectively act as a
+    condition for the affine half.'''
+
+    def __init__(self, dims_in, dims_c=[],
+                 subnet_constructor: Callable = None,
+                 clamp: float = 2.,
+                 clamp_activation: Union[str, Callable] = "SIGMOID",
+                 split_len: Union[float, int] = 3):
+        '''
+        Additional args in docstring of base class.
+
+        Args:
+          subnet_constructor: Function or class, with signature
+            constructor(dims_in, dims_out). The result should be a torch
+            nn.Module that takes dims_in input channels and dims_out
+            output channels. See tutorial for examples. Three subnetworks
+            will be initialized in the block.
+          clamp: Soft clamping for the multiplicative component. The
+            amplification or attenuation of each input dimension can be
+            at most exp(±clamp).
+          clamp_activation: Function to perform the clamping. String
+            values "ATAN", "TANH", and "SIGMOID" are recognized, or a
+            function object can be passed. TANH behaves like the original
+            realNVP paper. A custom function should take tensors and map
+            -inf to -1 and +inf to +1.
+        '''
+
+        super().__init__(dims_in, dims_c, clamp, clamp_activation,
+                         split_len=split_len)
+
+        self.phi = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1)
+        self.rho = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
+        self.mu = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
+
+    def forward(self, x, c=[], rev=False, jac=True):
+        x1, x2 = torch.split(x[0], [self.split_len1, self.split_len2], dim=1)
+        x1_c = torch.cat([x1, *c], 1) if self.conditional else x1
+        x2_c = torch.cat([x2, *c], 1) if self.conditional else x2
+
+        # notation:
+        #   x1, x2:  two halves of the input
+        #   y1, y2:  two halves of the output
+        #   rho, mu: multiplicative and additive subnets for y2
+        #   phi:     additive subnet for y1
+        #   s:       multiplicative coefficient for y2
+        #   j:       log det Jacobian, per batch element
+
+        if rev:
+            # in reverse mode, the incoming x1 plays the role of y1 below
+            s = self.rho(x1_c)
+            s = self.clamp * self.f_clamp(s)
+            y2 = (x2 - self.mu(x1_c)) * torch.exp(-s)
+            y2_c = torch.cat([y2, *c], 1) if self.conditional else y2
+
+            y1 = x1 - self.phi(y2_c)
+
+            j = -torch.sum(s, dim=tuple(range(1, s.dim())))
+        else:
+            y1 = x1 + self.phi(x2_c)
+            y1_c = torch.cat([y1, *c], 1) if self.conditional else y1
+
+            s = self.rho(y1_c)
+            s = self.clamp * self.f_clamp(s)
+            y2 = x2 * torch.exp(s) + self.mu(y1_c)
+
+            j = torch.sum(s, dim=tuple(range(1, s.dim())))
+
+        return (torch.cat((y1, y2), 1),), j
+
+
 class ConditionalAffineTransform(_BaseCouplingBlock):
     '''Similar to the conditioning layers from SPADE (Park et al, 2019): Perform
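
Note for reviewers: below is a minimal usage sketch of the new block, not part of the patch. It assumes the class ends up exported from FrEIA.modules; the subnet_conv helper, the (12, 16, 16) shape, and the 3/9 split are illustrative only.

    import torch
    import torch.nn as nn

    from FrEIA.modules import AffineCouplingLopSided  # assumed export

    def subnet_conv(c_in, c_out):
        # any c_in -> c_out nn.Module is a valid subnet
        return nn.Sequential(nn.Conv2d(c_in, 64, 3, padding=1), nn.ReLU(),
                             nn.Conv2d(64, c_out, 3, padding=1))

    # 12-channel input split into a 3-channel additive half (x1) and a
    # 9-channel affine half (x2), mirroring the IRN-style configuration
    block = AffineCouplingLopSided([(12, 16, 16)],
                                   subnet_constructor=subnet_conv,
                                   split_len=3)

    x = torch.randn(4, 12, 16, 16)
    (y,), log_jac = block([x])                    # forward
    (x_rec,), log_jac_rev = block([y], rev=True)  # inverse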
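
On the log-det: since y1 = x1 + phi(x2) is purely additive and y2 rescales x2 elementwise by exp(s), the Jacobian is block-triangular with diagonal entries (1, exp(s)), so its log determinant is just the sum of s over the transformed dimensions. Continuing the sketch above, a quick numerical self-check:

    # the inverse should reconstruct the input (up to float error), and the
    # reverse log-det should be the negation of the forward one
    assert torch.allclose(x, x_rec, atol=1e-5)
    assert torch.allclose(log_jac, -log_jac_rev, atol=1e-5)

    # one log-det entry per batch element
    print(log_jac.shape)  # torch.Size([4])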