Introduce lop-sided affine coupling (IRN, 2020) #115

Open · wants to merge 1 commit into base: master
67 changes: 67 additions & 0 deletions FrEIA/modules/coupling_layers.py
@@ -462,6 +462,73 @@ def forward(self, x, c=[], rev=False, jac=True):

        return (torch.cat((x1, y2), 1),), j

class AffineCouplingLopSided(_BaseCouplingBlock):
    '''A lop-sided coupling block derived from the GLOWCouplingBlock design,
    performing an affine transformation on one half of the inputs and an
    additive transformation on the other. As used in IRN (Xiao et al., 2020),
    the additive half effectively acts as a condition for the affine half.'''

    def __init__(self, dims_in, dims_c=[],
                 subnet_constructor: Callable = None,
                 clamp: float = 2.,
                 clamp_activation: Union[str, Callable] = "SIGMOID",
                 split_len: Union[float, int] = 3):
        '''
        Additional args in docstring of base class.

        Args:
          subnet_constructor: function or class, with signature
            constructor(dims_in, dims_out). The result should be a torch
            nn.Module that takes dims_in input channels and dims_out output
            channels. See tutorial for examples. Three subnetworks will be
            initialized in the block.
          clamp: Soft clamping for the multiplicative component. The
            amplification or attenuation of each input dimension can be at most
            exp(±clamp).
          clamp_activation: Function to perform the clamping. String values
            "ATAN", "TANH", and "SIGMOID" are recognized, or a function
            object can be passed. TANH behaves like the original realNVP paper.
            A custom function should take tensors and map -inf to -1 and +inf to +1.
        '''

        super().__init__(dims_in, dims_c, clamp, clamp_activation,
                         split_len=split_len)
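
        # phi maps the second half (plus conditions) to an additive update for
        # the first half; rho and mu map the first half (plus conditions) to
        # scale and shift coefficients for the second half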
        self.phi = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1)
        self.rho = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
        self.mu = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)

    def forward(self, x, c=[], rev=False, jac=True):
        x1, x2 = torch.split(x[0], [self.split_len1, self.split_len2], dim=1)
        x1_c = torch.cat([x1, *c], 1) if self.conditional else x1
        x2_c = torch.cat([x2, *c], 1) if self.conditional else x2

        # notation:
        # x1, x2: two halves of the input
        # y1, y2: two halves of the output
        # rho, mu: multiplicative, additive subnets for y2
        # phi: additive subnet for y1
        # s: multiplicative coefficient for y2
        # j: log det Jacobian
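        #
        # forward:  y1 = x1 + phi(x2),  y2 = x2 * exp(s) + mu(y1),
        #           with s = clamp * f_clamp(rho(y1))
        # inverse:  x2 = (y2 - mu(y1)) * exp(-s),  x1 = y1 - phi(x2)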

        if rev:
            s = self.rho(x1_c)
            s = self.clamp * self.f_clamp(s)
            y2 = (x2 - self.mu(x1_c)) * torch.exp(-s)
            y2_c = torch.cat([y2, *c], 1) if self.conditional else y2

            y1 = x1 - self.phi(y2_c)

            # per-sample log det: sum over all non-batch dimensions
            j = -torch.sum(s, dim=self.sum_dims)
        else:
            y1 = x1 + self.phi(x2_c)
            y1_c = torch.cat([y1, *c], 1) if self.conditional else y1

            s = self.rho(y1_c)
            s = self.clamp * self.f_clamp(s)
            y2 = x2 * torch.exp(s) + self.mu(y1_c)

            j = torch.sum(s, dim=self.sum_dims)

        return (torch.cat((y1, y2), 1),), j

class ConditionalAffineTransform(_BaseCouplingBlock):
    '''Similar to the conditioning layers from SPADE (Park et al., 2019): Perform
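
For reference, a minimal usage sketch, not part of the diff above: it assumes the block is exported as Fm.AffineCouplingLopSided once merged and uses FrEIA's standard SequenceINN API; the subnet, channel counts, and shapes are only illustrative.

import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def subnet(dims_in, dims_out):
    # any nn.Module mapping dims_in channels to dims_out channels works here
    return nn.Sequential(nn.Conv2d(dims_in, 64, 3, padding=1),
                         nn.ReLU(),
                         nn.Conv2d(64, dims_out, 3, padding=1))

# 12-channel inputs; split off 3 channels for the additive half, as in IRN
inn = Ff.SequenceINN(12, 16, 16)
for _ in range(4):
    inn.append(Fm.AffineCouplingLopSided, subnet_constructor=subnet, split_len=3)

x = torch.randn(8, 12, 16, 16)
z, log_jac_det = inn(x)        # forward pass, per-sample log|det J|
x_rec, _ = inn(z, rev=True)    # inverse pass recovers the input
assert torch.allclose(x, x_rec, atol=1e-4)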