Error in your code #89

Open
ghost opened this issue Jul 3, 2021 · 2 comments

Comments

@ghost

ghost commented Jul 3, 2021

When I ran your tutorial code from https://github.com/VLL-HD/FrEIA#tutorial:
```python
import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

in1 = Ff.InputNode(100, name='Input 1') # 1D vector
in2 = Ff.InputNode(20, name='Input 2') # 1D vector
cond = Ff.ConditionNode(42, name='Condition')

def subnet(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 256), nn.ReLU(),
                         nn.Linear(256, dims_out))

perm = Ff.Node(in1, Fm.PermuteRandom, {}, name='Permutation')
split1 = Ff.Node(perm, Fm.Split, {}, name='Split 1')
split2 = Ff.Node(split1.out1, Fm.Split, {}, name='Split 2')
actnorm = Ff.Node(split2.out1, Fm.ActNorm, {}, name='ActNorm')
concat1 = Ff.Node([actnorm.out0, in2.out0], Fm.Concat, {}, name='Concat 1')
affine = Ff.Node(concat1, Fm.AffineCouplingOneSided, {'subnet_constructor': subnet},
                 conditions=cond, name='Affine Coupling')
concat2 = Ff.Node([split2.out0, affine.out0], Fm.Concat, {}, name='Concat 2')

output1 = Ff.OutputNode(split1.out0, name='Output 1')
output2 = Ff.OutputNode(concat2, name='Output 2')

example_INN = Ff.GraphINN([in1, in2, cond,
                           perm, split1, split2,
                           actnorm, concat1, affine, concat2,
                           output1, output2])

# dummy inputs:
x1, x2, c = torch.randn(1, 100), torch.randn(1, 20), torch.randn(1, 42)

# compute the outputs
(z1, z2), log_jac_det = example_INN([x1, x2], c=c)

# invert the network and check if we get the original inputs back:
(x1_inv, x2_inv), log_jac_det_inv = example_INN([z1, z2], c=c, rev=True)

# every entry of x2_inv is NaN
print(x2_inv)
```
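The comment in the snippet says to check whether the original inputs come back, but the code only prints `x2_inv`. A minimal sketch of such a check follows, using a hypothetical helper `inversion_ok` (not part of FrEIA). It tests for NaN explicitly, because an all-NaN reconstruction would otherwise slip past a plain tolerance comparison.

```python
import torch


def inversion_ok(originals, reconstructions, atol=1e-5):
    """Check that each reconstruction is NaN-free and close to its original.

    Hypothetical helper for illustration; it is not part of FrEIA.
    """
    if len(originals) != len(reconstructions):
        raise ValueError("originals and reconstructions differ in length")
    for x, x_inv in zip(originals, reconstructions):
        # torch.max propagates NaN and `nan > atol` is False, so a NaN
        # reconstruction would pass the tolerance test below; catch it first
        if torch.isnan(x_inv).any():
            return False
        # Largest absolute element-wise difference must stay within atol
        if (x - x_inv).abs().max().item() > atol:
            return False
    return True


x = torch.randn(4, 3)
print(inversion_ok([x], [x.clone()]))                         # True
print(inversion_ok([x], [torch.full_like(x, float("nan"))]))  # False
```

Applied to the snippet above, `inversion_ok([x1, x2], [x1_inv, x2_inv])` would return `False` because of the NaNs in `x2_inv`.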

@ardizzone
Member

Hi!

Thanks for pointing this out! I thought we had already diagnosed and fixed this issue at some point, but it seems we haven't merged the fix yet.
The problem is that ActNorm (= batch norm for INNs) is used with batch size 1, where the standard deviation can't be computed: it comes out as NaN, and that NaN then spreads through the layers downstream of ActNorm.
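To make this concrete, here is a deliberately simplified, hypothetical ActNorm-style layer (`ToyActNorm`, not FrEIA's actual `Fm.ActNorm`) that sets its scale and bias from the first batch it sees. It assumes 2D inputs of shape `(batch, features)` and uses `torch.std`, which by default divides by N - 1; with a single sample that is 0/0, so the scale, and everything computed from it, becomes NaN.

```python
import torch


class ToyActNorm(torch.nn.Module):
    """Simplified ActNorm-style layer for inputs of shape (batch, features).

    Hypothetical illustration, not FrEIA's implementation: on the first
    forward pass it chooses a per-feature scale and bias so that this batch
    ends up with zero mean and unit standard deviation.
    """

    def __init__(self, features):
        super().__init__()
        self.log_scale = torch.nn.Parameter(torch.zeros(features))
        self.bias = torch.nn.Parameter(torch.zeros(features))
        self.initialized = False

    def forward(self, x, rev=False):
        if not self.initialized and not rev:
            with torch.no_grad():
                # Unbiased std divides by N - 1: a batch of one gives NaN
                std = x.std(dim=0)
                self.log_scale.copy_(torch.log(1.0 / std))
                self.bias.copy_(-(x * self.log_scale.exp()).mean(dim=0))
            self.initialized = True
        if rev:
            # Inverse of the affine map used in the forward direction
            return (x - self.bias) / self.log_scale.exp()
        return x * self.log_scale.exp() + self.bias


torch.manual_seed(0)
for batch_size in (1, 8):
    layer = ToyActNorm(20)
    x = torch.randn(batch_size, 20)
    x_inv = layer(layer(x), rev=True)
    print(batch_size, torch.isnan(x_inv).any().item())
# 1 True
# 8 False
```

With a batch of 8, the same toy layer normalizes the batch and inverts cleanly, which is consistent with the workaround of using a larger batch.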

If you give all the dummy inputs a batch size larger than 1, it works, e.g. `x1, x2, c = torch.randn(8, 100), torch.randn(8, 20), torch.randn(8, 42)`.
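Assuming, as the explanation above implies, that ActNorm estimates its per-dimension statistics from the first batch it sees, one way to apply this workaround consistently is to validate the inputs before the first forward pass. `check_actnorm_batch` below is a hypothetical helper, not a FrEIA function, and the minimum of 2 samples is only the bare requirement for a standard deviation to exist.

```python
import torch


def check_actnorm_batch(*tensors, min_batch=2):
    """Raise ValueError if the inputs cannot sensibly initialize ActNorm.

    Hypothetical helper, not part of FrEIA: all inputs must share one batch
    size (dimension 0) of at least `min_batch` samples.
    """
    sizes = {t.shape[0] for t in tensors}
    if len(sizes) != 1:
        raise ValueError(f"inputs have different batch sizes: {sorted(sizes)}")
    (batch_size,) = sizes
    if batch_size < min_batch:
        raise ValueError(
            f"batch size {batch_size} is too small for ActNorm initialization; "
            f"use at least {min_batch} samples"
        )
    return batch_size


x1, x2, c = torch.randn(8, 100), torch.randn(8, 20), torch.randn(8, 42)
print(check_actnorm_batch(x1, x2, c))  # 8
```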

I will try to find out what happened to the fix in the meantime.

@ghost
Author

ghost commented Jul 6, 2021

Thanks for your help, @ardizzone! It works now. I hope you will be able to fix the problem for batch size 1.

ardizzone reopened this on Jul 20, 2021