
Grad check failed #3

Open

justanhduc opened this issue Jan 11, 2020 · 2 comments

@justanhduc

Your implementation failed grad check.

import torch as T
from torch.autograd import gradcheck
from emd import earth_mover_distance  # this repo's wrapper (module name assumed)

x = T.rand(2, 4, 3).cuda().double().requires_grad_(True)
y = T.rand(2, 5, 3).cuda().double()
print(gradcheck(earth_mover_distance, (x, y)))

One bug is perhaps here and here. Probably you want to cast a value to scalar_t before the division.
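Roughly what I mean by the cast, as a Python analogy (this is not the kernel code):

# If the division is done in integer arithmetic, the result is truncated
# instead of being computed in floating point (what scalar_t would give).
n, m = 7, 2
print(n // m)        # 3    -- truncated integer division
print(float(n) / m)  # 3.5  -- division after casting to a floating type first
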
I am not familiar with CUDA, so I couldn't get any further. Do you have any idea how to solve this?

@daerduoCarey
Owner

You need to input the point clouds in Bx3xN format, since earth_mover_distance has transpose=True by default.
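For example, the intended call would look roughly like this (shapes and the module name are just for illustration):

import torch as T
from torch.autograd import gradcheck
from emd import earth_mover_distance  # this repo's wrapper (module name assumed)

# B x 3 x N point clouds; with the default transpose=True they are transposed
# to B x N x 3 internally before the CUDA extension is called.
x = T.rand(2, 3, 4).cuda().double().requires_grad_(True)
y = T.rand(2, 3, 5).cuda().double()
print(gradcheck(earth_mover_distance, (x, y)))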

@justanhduc
Author

My bad for the confusion. I changed the default to transpose=False in my copy of the code, since my inputs are already in BxNx3 format; otherwise, the code would throw a shape error.
As a reproducible example, this code

import torch
import numpy as np
import emd_cuda


class EarthMoverDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
        # Compute the approximate matching and its transport cost with the CUDA extension.
        match = emd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost):
        xyz1, xyz2, match = ctx.saved_tensors
        grad_cost = grad_cost.contiguous()
        # The matching weights are treated as constant; only the cost is differentiated.
        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2


def earth_mover_distance(xyz1, xyz2, transpose=False):
    """Earth Mover Distance (Approx)

    Args:
        xyz1 (torch.Tensor): (b, 3, n1)
        xyz2 (torch.Tensor): (b, 3, n2)
        transpose (bool): whether to transpose inputs as it might be BCN format.
            Extensions only support BNC format.

    Returns:
        cost (torch.Tensor): (b)

    """
    if xyz1.dim() == 2:
        xyz1 = xyz1.unsqueeze(0)
    if xyz2.dim() == 2:
        xyz2 = xyz2.unsqueeze(0)
    if transpose:
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
    cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
    return cost


import torch as T
from torch.autograd import gradcheck
x = T.arange(2 * 4 * 3).cuda().double().reshape(2, 4, 3).requires_grad_(True)
y = T.arange(2 * 5 * 3).cuda().double().reshape(2, 5, 3)
print(gradcheck(earth_mover_distance, (x, y)))

throws

Traceback (most recent call last):
  File "/home/justanhduc/Workspace/cpp/emd-cpp/test.py", line 93, in <module>
    print(gradcheck(earth_mover_distance, (x, y)))
  File "/home/justanhduc/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 289, in gradcheck
    'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
  File "/home/justanhduc/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 227, in fail_test
    raise RuntimeError(msg)
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ -6.0000,   0.0000],
        [ -6.0000,   0.0000],
        [ -6.0000,   0.0000],
        [ -4.5000,   0.0000],
        [ -4.5000,   0.0000],
        [ -4.5000,   0.0000],
        [ -3.6035,   0.0000],
        [ -3.6035,   0.0000],
        [ -3.6035,   0.0000],
        [ -1.4344,   0.0000],
        [ -1.4344,   0.0000],
        [ -1.4344,   0.0000],
        [  0.0000, -34.2869],
        [  0.0000, -34.2869],
        [  0.0000, -34.2869],
        [  0.0000,  -5.3574],
        [  0.0000,  -5.3574],
        [  0.0000,  -5.3574],
        [  0.0000,  -1.2887],
        [  0.0000,  -1.2887],
        [  0.0000,  -1.2887],
        [  0.0000,   2.3111],
        [  0.0000,   2.3111],
        [  0.0000,   2.3111]], dtype=torch.float64)
analytical:tensor([[ -6.0000,  -0.0000],
        [ -6.0000,  -0.0000],
        [ -6.0000,  -0.0000],
        [ -4.5000,  -0.0000],
        [ -4.5000,  -0.0000],
        [ -4.5000,  -0.0000],
        [ -3.0000,  -0.0000],
        [ -3.0000,  -0.0000],
        [ -3.0000,  -0.0000],
        [ -1.5000,  -0.0000],
        [ -1.5000,  -0.0000],
        [ -1.5000,  -0.0000],
        [ -0.0000, -34.2869],
        [ -0.0000, -34.2869],
        [ -0.0000, -34.2869],
        [ -0.0000,  -5.3574],
        [ -0.0000,  -5.3574],
        [ -0.0000,  -5.3574],
        [ -0.0000,  -3.8558],
        [ -0.0000,  -3.8558],
        [ -0.0000,  -3.8558],
        [ -0.0000,  -1.5000],
        [ -0.0000,  -1.5000],
        [ -0.0000,  -1.5000]], dtype=torch.float64)


Process finished with exit code 1
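
For reference, a small exact-matching version written directly against autograd (my own sketch, using scipy's Hungarian solver and restricted to equal-sized clouds so a one-to-one matching exists) should pass gradcheck, since the assignment is computed on detached values and treated as constant. Comparing its gradients with the extension's on tiny inputs might help localize the bug.

import torch
from torch.autograd import gradcheck
from scipy.optimize import linear_sum_assignment


def exact_match_cost(xyz1, xyz2):
    # xyz1, xyz2: (b, n, 3) with the same n, so a one-to-one matching exists.
    costs = []
    for a, b in zip(xyz1, xyz2):
        d = torch.cdist(a, b)  # (n, n) pairwise Euclidean distances
        # Hungarian assignment on detached distances; held fixed during backward.
        row, col = linear_sum_assignment(d.detach().cpu().numpy())
        costs.append(d[row, col].sum())
    return torch.stack(costs)


x = torch.rand(2, 4, 3, dtype=torch.float64, requires_grad=True)
y = torch.rand(2, 4, 3, dtype=torch.float64)
print(gradcheck(exact_match_cost, (x, y)))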
