Thanks for pointing this out. I thought we had already diagnosed and fixed this issue at some point, but it seems the fix has not been merged yet.
The problem is that ActNorm (the INN counterpart of batch norm) is used with batch size 1, for which the standard deviation cannot be computed.
If you give all the dummy inputs a batch size larger than 1, it works. For example:
x1, x2, c = torch.randn(8, 100), torch.randn(8, 20), torch.randn(8, 42)
I will try to find out what happened to the fix in the meantime.
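To make the batch-size problem concrete, here is a minimal pure-Python sketch. It assumes a data-dependent initialization that sets one scale per feature to 1 / std over the first batch; `sample_std` and `init_scale` are hypothetical helpers written for illustration and are not FrEIA's actual ActNorm code. With a single sample, the Bessel-corrected standard deviation is undefined (0/0), so every scale becomes NaN and the NaN then propagates through the network.

```python
import math


def sample_std(values):
    """Bessel-corrected standard deviation; NaN for fewer than two samples."""
    n = len(values)
    if n < 2:
        # The (n - 1) denominator is zero here; torch.std also returns NaN
        # for a single sample with its default correction.
        return float("nan")
    mean = sum(values) / n
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (n - 1))


def init_scale(batch):
    """Hypothetical data-dependent init: per-feature scale = 1 / std over the batch."""
    features = list(zip(*batch))  # transpose rows (samples) into columns (features)
    return [1.0 / sample_std(f) for f in features]


# Batch size 1: the standard deviation is undefined, so every scale is NaN.
single = [[0.3, -1.2, 0.7]]
# Batch size 3: every feature has a well-defined, non-zero spread.
larger = [[0.3, -1.2, 0.7], [1.1, 0.4, -0.5], [-0.6, 0.9, 0.2]]

scale_single = init_scale(single)
scale_larger = init_scale(larger)

print(all(math.isnan(s) for s in scale_single))     # True
print(all(math.isfinite(s) for s in scale_larger))  # True
```

This is why giving the dummy inputs a batch size larger than 1 avoids the NaN values.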
When I ran your tutorial code from https://github.com/VLL-HD/FrEIA#tutorial, x2_inv came back as all NaN:
import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

in1 = Ff.InputNode(100, name='Input 1') # 1D vector
in2 = Ff.InputNode(20, name='Input 2') # 1D vector
cond = Ff.ConditionNode(42, name='Condition')
def subnet(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 256), nn.ReLU(),
                         nn.Linear(256, dims_out))
perm = Ff.Node(in1, Fm.PermuteRandom, {}, name='Permutation')
split1 = Ff.Node(perm, Fm.Split, {}, name='Split 1')
split2 = Ff.Node(split1.out1, Fm.Split, {}, name='Split 2')
actnorm = Ff.Node(split2.out1, Fm.ActNorm, {}, name='ActNorm')
concat1 = Ff.Node([actnorm.out0, in2.out0], Fm.Concat, {}, name='Concat 1')
affine = Ff.Node(concat1, Fm.AffineCouplingOneSided, {'subnet_constructor': subnet},
                 conditions=cond, name='Affine Coupling')
concat2 = Ff.Node([split2.out0, affine.out0], Fm.Concat, {}, name='Concat 2')
output1 = Ff.OutputNode(split1.out0, name='Output 1')
output2 = Ff.OutputNode(concat2, name='Output 2')
example_INN = Ff.GraphINN([in1, in2, cond,
                           perm, split1, split2,
                           actnorm, concat1, affine, concat2,
                           output1, output2])
# dummy inputs:
x1, x2, c = torch.randn(1, 100), torch.randn(1, 20), torch.randn(1, 42)

# compute the outputs
(z1, z2), log_jac_det = example_INN([x1, x2], c=c)

# invert the network and check if we get the original inputs back:
(x1_inv, x2_inv), log_jac_det_inv = example_INN([z1, z2], c=c, rev=True)

# x2_inv is all NaN
print(x2_inv)