Skip to content

Commit

Permalink
Adjusted ActNorm to work as described in the paper
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Sep 26, 2023
1 parent 4ebc35c commit 84cfcb6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
22 changes: 19 additions & 3 deletions FrEIA/modules/invertible_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None):

self.register_buffer("is_initialized", torch.tensor(False))

dims = next(iter(dims_in))
dims = list(next(iter(dims_in)))
dims[2:] = [1] * len(dims[2:])
self.log_scale = nn.Parameter(torch.empty(1, *dims))
self.loc = nn.Parameter(torch.empty(1, *dims))

Expand All @@ -42,9 +43,24 @@ def scale(self):
return torch.exp(self.log_scale)

def initialize(self, batch: torch.Tensor):
if batch.ndim != self.log_scale.ndim:
raise ValueError(f"Expected batch of dimension {self.log_scale.ndim}, but got {batch.ndim}.")

# we draw the mean and std over all dimensions except the channel dimension
dims = [0] + list(range(2, batch.ndim))

loc = torch.mean(batch, dim=dims, keepdim=True)
scale = torch.std(batch, dim=dims, keepdim=True)

# check for zero std
if torch.any(torch.isclose(scale, torch.tensor(0.0))):
raise ValueError("Failed to initialize ActNorm: One or more channels have zero standard deviation.")

# slice here to avoid silent device move
self.log_scale.data[:] = torch.log(scale)
self.loc.data[:] = loc

self.is_initialized.data = torch.tensor(True)
self.log_scale.data = torch.log(torch.std(batch, dim=0, keepdim=True))
self.loc.data = torch.mean(batch, dim=0, keepdim=True)

def output_dims(self, input_dims):
assert len(input_dims) == 1, "Can only use one input"
Expand Down
8 changes: 6 additions & 2 deletions tests/test_invertible_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def test_conv(self):
self.assertStandardMoments(y_)

def assertStandardMoments(self, data):
self.assertTrue(torch.allclose(torch.mean(data, dim=0), torch.zeros(data.shape[-1]), atol=1e-7))
self.assertTrue(torch.allclose(torch.std(data, dim=0), torch.ones(data.shape[-1])))
dims = [0] + list(range(2, data.ndim))
mean = torch.mean(data, dim=dims)
std = torch.std(data, dim=dims)

self.assertTrue(torch.allclose(mean, torch.zeros_like(mean), atol=1e-7))
self.assertTrue(torch.allclose(std, torch.ones_like(std)))


class IResNetTest(unittest.TestCase):
Expand Down

0 comments on commit 84cfcb6

Please sign in to comment.