Correlation Between Spheres #50
It sounds good. I can't tell offhand what could be wrong... maybe a transpose or conjugate is missing?
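The conjugate hint can be illustrated with a 1-D analogy in plain NumPy (not s2cnn code). FFT-based cross-correlation multiplies one spectrum by the complex conjugate of the other. If the conjugate is dropped, the product becomes a convolution, and its peak is generally at the wrong shift. On the sphere, the per-degree product in s2_mm plays the role of the spectral product, and `complex_mm(..., conj_y=True)` supplies the conjugate. The helper `best_shift` is hypothetical and exists only for this illustration.

```python
import numpy as np


def best_shift(f, g, conjugate=True):
    """Estimate the cyclic shift s such that f[n] ~ g[n - s].

    With conjugate=True this is FFT cross-correlation:
        corr[s] = sum_n f[n] * conj(g[n - s]).
    With conjugate=False the spectral product is a cyclic convolution,
    whose peak is generally NOT at the true shift.
    """
    F = np.fft.fft(f)
    G = np.fft.fft(g)
    if conjugate:
        G = np.conj(G)
    corr = np.fft.ifft(F * G).real
    return int(np.argmax(corr))


g = np.array([1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
f = np.roll(g, 3)  # f is g cyclically shifted by 3 samples

print(best_shift(f, g))                   # 3: correlation recovers the shift
print(best_shift(f, g, conjugate=False))  # 5: convolution peak is elsewhere
```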
Hi Mario! That was quite a fast reply! Thank you :)

```python
import torch
import math
import numpy as np
from s2cnn.s2_ft import s2_rft
from s2cnn.soft.so3_fft import SO3_ifft_real
from s2cnn.soft.s2_fft import S2_fft_real
from utils.utils import fftshift3d
from data.simulation_3d import get_simulation_3d
from log_sphere.log_sphere import sphere_transformer
from s2cnn import s2_mm


def unravel_indices(indices, shape):
    """Converts flat indices into unraveled coordinates in a target shape.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    coord = torch.stack(coord[::-1], dim=-1)
    return coord


template, source, rotation_gt, translation_gt, scale_gt = get_simulation_3d(50, 50, 50, 1)
template = torch.tensor(template[0]).float().permute(3, 0, 1, 2)
source = torch.tensor(source[0]).float().permute(3, 0, 1, 2)
device = torch.device("cpu")

# create two signals on the sphere with shape [b, feature_in, beta, alpha]
bw_in = 25
bw_out = 25
print("grid", source.shape)
sphere1 = sphere_transformer(source.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[..., 20:40].sum(-1).float()
sphere2 = sphere_transformer(template.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[..., 20:40].sum(-1).float()
# sphere with the size of [b, theta, phi]

sphere_1_fft = S2_fft_real.apply(sphere1, bw_out)
sphere_2_fft = S2_fft_real.apply(sphere2, bw_out)

z = s2_mm(sphere_1_fft, sphere_2_fft).unsqueeze(-2)  # [l * m * n, batch, feature_out, complex]
z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]
z = z.squeeze(1)  # [batch, beta, alpha, gamma]
# z = fftshift3d(z).unsqueeze(1)

# flatten (beta, alpha, gamma) so argmax searches the whole SO(3) grid;
# after squeeze(1) z is 4-D, so z.size(4) would not exist
cor_argmax = torch.argmax(z.reshape(z.size(0), -1), -1)
index = unravel_indices(cor_argmax, z.shape[1:])  # [batch, 3] -> (beta, alpha, gamma)
print(index)
```

and the modified s2_mm is something like this:

```python
import torch


def s2_mm(x, y):
    '''
    :param x: [l * m, batch, complex]
    :param y: [l * m, batch, complex]
    :return: [l * m * n, batch, complex]
    '''
    from s2cnn.utils.complex import complex_mm

    # assert y.size(3) == 2
    # assert x.size(3) == 2
    nbatch = x.size(1)
    # nfeature_in = x.size(2)
    # nfeature_out = y.size(2)
    # assert y.size(1) == nfeature_in
    nspec = x.size(0)
    # assert y.size(0) == nspec

    # if x.is_cuda:
    #     return _cuda_S2_mm.apply(x, y)

    nl = round(nspec**0.5)

    batch_list = []
    for b in range(nbatch):
        x_batch = x[:, b, ...].unsqueeze(1)  # [l * m, 1, complex]
        y_batch = y[:, b, ...].unsqueeze(1)  # [l * m, 1, complex]
        Fz_list = []
        begin = 0
        for l in range(nl):
            L = 2 * l + 1
            size = L
            Fx = x_batch[begin:begin+size]  # [m, 1, complex]
            Fy = y_batch[begin:begin+size]  # [m, 1, complex]
            Fy = Fy.transpose(0, 1)  # [1, m, complex]
            Fy = Fy.contiguous()
            Fz = complex_mm(Fx, Fy, conj_y=True)  # [m, m, complex]
            Fz = Fz.view(L * L, 2)  # [m * m, complex]
            # print('fffff', Fz.shape)
            Fz_list.append(Fz)
            begin += size
        z_batch = torch.cat(Fz_list, 0)  # [l * m * m, complex]
        batch_list.append(z_batch)
    z = torch.stack(batch_list, dim=1)  # [l * m * m, batch, complex]
    print('shape', z.shape)
    return z
```

When I tried to make Sphere1 = Sphere2, the output of the S2 FFT is consistent, but the argmax coordinate changes randomly from run to run. Hopefully this gives you a hint about my mistake?
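To check shapes and offsets independently of the torch code, the per-degree outer product computed by the modified s2_mm can be written in a few lines of NumPy on complex arrays. This is a hypothetical reference (`s2_mm_reference` is not part of s2cnn). It assumes the same layout as the code above: degree l occupies rows l**2 to (l+1)**2 - 1 of the [l * m, batch] input, and each output block is Fx[m] * conj(Fy[n]), flattened row-major like `Fz.view(L * L, 2)`.

```python
import numpy as np


def s2_mm_reference(x, y):
    """Per-degree outer products x_l[m] * conj(y_l[n]) for every batch element.

    x, y: complex arrays [b**2, batch]; degree l sits at rows l**2 .. (l+1)**2 - 1.
    Returns a complex array [sum_l (2l+1)**2, batch] = [b * (4b**2 - 1) // 3, batch].
    """
    nspec, nbatch = x.shape
    nl = round(nspec ** 0.5)
    assert nl * nl == nspec and y.shape == x.shape
    blocks = []
    for l in range(nl):
        s = slice(l ** 2, (l + 1) ** 2)
        # [m, 1, batch] * [1, n, batch] -> [m, n, batch], flattened to [m * n, batch]
        block = x[s][:, None, :] * np.conj(y[s])[None, :, :]
        blocks.append(block.reshape(-1, nbatch))
    return np.concatenate(blocks, axis=0)


rng = np.random.default_rng(0)
b, nbatch = 3, 2
x = rng.standard_normal((b ** 2, nbatch)) + 1j * rng.standard_normal((b ** 2, nbatch))
z = s2_mm_reference(x, x)
print(z.shape)  # (35, 2), i.e. b * (4 * b**2 - 1) // 3 rows, as asserted in so3_ifft
```

To compare against the torch version, which stores complex numbers as a trailing dimension of size 2, a tensor `t` can be converted with `t[..., 0].numpy() + 1j * t[..., 1].numpy()`. If both agree on random inputs, the layout of s2_mm is probably not the cause of the unstable argmax.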
Hi Mario! For S2FFT:

```python
import torch


def s2_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, complex]
    :return: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex]

    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out ** 2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x)))  # [batch, beta, m, complex]
    # x = torch.fft.fft(x,1)  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=x.device.index)
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
                stream=stream)
    # for l in range(b_out):
    #     s = slice(l ** 2, l ** 2 + 2 * l + 1)
    #     xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
    #     output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output


def s2_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                stream=stream)
    # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    # output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
    # for l in range(b_in):
    #     s = slice(l ** 2, l ** 2 + 2 * l + 1)
    #     out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
    #     output[:, :, :l + 1] += out[:, :, -l - 1:]
    #     if l > 0:
    #         output[:, :, -l:] += out[:, :, :l]

    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2)  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output


class S2_fft_real(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        from s2cnn.utils.complex import as_complex
        ctx.b_out = b_out
        ctx.b_in = x.size(-1) // 2
        return s2_fft(as_complex(x), b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        return s2_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None
```

For SO3_FFT:

```python
import torch


def so3_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m * n, ..., complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    # if x.is_cuda and x.dtype == torch.float32:
    cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
    cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    # else:
    #     output.fill_(0)
    #     for l in range(min(b_in, b_out)):
    #         s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
    #         out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
    #         l1 = min(l, b_out - 1)  # if b_out < b_in
    #         output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
    #         if l > 0:
    #             output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
    #             output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
    #             output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    output = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(output), dim=[2, 3])) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]
    # output = torch.view_as_real((torch.fft.ifft(output, 2) * output.size(-2) ** 2)[...,0])  # [batch, beta, alpha, gamma, complex]
    return output


def so3_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, gamma, complex]
    :return: [l * m * n, ..., complex]
    '''
    assert x.size(-1) == 2, x.size()
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    assert x.size(-4) == 2 * b_in
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[:-4]

    # x = x.view(-1, 2 * b_in, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, gamma, complex]

    '''
    :param x: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_in, 2 b_in, 2 b_in, 2)
    :return: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2)
    '''
    nspec = b_out * (4 * b_out ** 2 - 1) // 3
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)  # [beta, l * m * n]

    # x = torch.fft(x, 2)  # [batch, beta, m, n, complex]
    x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=[2, 3]))

    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=False, device=x.device.index)
    cuda_kernel(x, wigner, output)  # [l * m * n, batch, complex]

    output = output.view(-1, *batch_size, 2)  # [l * m * n, ..., complex]
    return output


class SO3_ifft_real(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        nspec = x.size(0)
        ctx.b_out = b_out
        ctx.b_in = round((3 / 4 * nspec) ** (1 / 3))
        return so3_ifft(x, b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        output = so3_fft(grad_output, for_grad=True, b_out=ctx.b_in).unsqueeze(-2)
        return output, None
```
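The main change in these functions is replacing the legacy `torch.fft(x, 1)` / `torch.ifft(x, 1)` calls (deprecated and later removed from PyTorch), which worked on real tensors with a trailing dimension of size 2, by the `torch.fft` module wrapped in `torch.view_as_complex` / `torch.view_as_real`. The NumPy sketch below is an analogy, not the patched s2cnn code; `as_complex`, `as_real` and `naive_dft` are hypothetical helpers. It checks the three facts the rewrite relies on: packing `[..., 2]` pairs into complex numbers and transforming the last axis reproduces the 1-D DFT over alpha; multiplying the default (1/n-normalised) inverse by n, as in `* output.size(-2)`, gives the unnormalised inverse sum; and the same holds for the 2-D inverse over (alpha, gamma) with the `(2b)**2` factor used in so3_ifft.

```python
import numpy as np


def as_complex(x):
    """[..., 2] real (re, im) pairs -> complex array [...] (like torch.view_as_complex)."""
    return x[..., 0] + 1j * x[..., 1]


def as_real(z):
    """Complex array [...] -> [..., 2] real pairs (like torch.view_as_real)."""
    return np.stack([z.real, z.imag], axis=-1)


def naive_dft(v, sign=-1):
    """Unnormalised O(n^2) sum over the last axis: V[k] = sum_j v[j] exp(sign 2i pi j k / n)."""
    n = v.shape[-1]
    j = np.arange(n)
    return v @ np.exp(sign * 2j * np.pi * np.outer(j, j) / n)


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 6, 2))  # [batch, beta, alpha, complex]

# forward FFT over alpha, as in s2_fft
X = as_real(np.fft.fft(as_complex(x), axis=-1))
print(np.allclose(as_complex(X), naive_dft(as_complex(x))))  # True

# the default ifft divides by n, so "* n" undoes it, as in s2_ifft
n = x.shape[-2]
y = as_real(np.fft.ifft(as_complex(x), axis=-1)) * n
print(np.allclose(as_complex(y), naive_dft(as_complex(x), sign=+1)))  # True

# 2-D inverse over (alpha, gamma) times m**2, as in so3_ifft
w = as_complex(rng.standard_normal((2, 4, 4, 4, 2)))  # [batch, beta, alpha, gamma]
m = w.shape[-1]
k = np.arange(m)
E = np.exp(2j * np.pi * np.outer(k, k) / m)
ref = np.einsum("zbag,ai,gj->zbij", w, E, E)
print(np.allclose(np.fft.ifftn(w, axes=(2, 3)) * m ** 2, ref))  # True
```

Both NumPy and `torch.fft` also accept `norm="forward"`, under which the inverse transform is unscaled; that should be equivalent to the explicit multiplication, but it has not been tested in the patched code.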
Very nice! This will help some people. Yes, please make a PR! I will revert the last merge so that I can compare the original code with yours in the PR.
Hi,
Thank you so much for your wonderful work, I really appreciate it! The code is really easy to use. However, I ran into some problems when trying to calculate the rotation R in SO(3) between two rotated spheres.
Basically, I followed s2cnn/s2cnn/soft/s2_conv.py and replaced the torch kernel y with a second sphere.
Sphere1 and Sphere2 are passed into S2_fft_real.apply(), giving Sphere1_FFT and Sphere2_FFT. The correlation is then computed by s2_mm(Sphere1_FFT, Sphere2_FFT), with slight modifications to the channels and shapes. Finally, the correlation is passed to SO3_ifft_real.apply(), and the argmax of the result gives the ZYZ angles.
I was wondering whether this is the correct way to use the code to calculate the rotation between two rotated spheres, because so far the result seems incorrect.
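The argmax step in this pipeline returns grid indices rather than angles. Below is a minimal sketch of the conversion. It assumes SOFT sampling for the SO(3) grid (beta_j = pi (2j + 1) / (4b), alpha_k = gamma_k = 2 pi k / (2b)), the [beta, alpha, gamma] axis order from the SO3_ifft_real docstring, and R = Rz(alpha) Ry(beta) Rz(gamma). All helper names are hypothetical. Whether the peak corresponds to R or its inverse should be checked against a synthetic rotation with a known answer. One consequence of the assumed grid is that beta never reaches 0. Near the pole, a rotation depends almost only on alpha + gamma, so for identical spheres the peak can appear at any (0, k, l) with k + l ≡ 0 (mod 2b). That can make the index look arbitrary even when the rotation is nearly the identity, so comparing rotations by geodesic distance is more informative than comparing raw indices.

```python
import numpy as np


def soft_grid_angles(index, b):
    """Map (j, k, l) on a [2b, 2b, 2b] (beta, alpha, gamma) grid to ZYZ angles.

    Assumed SOFT sampling: beta_j = pi (2j + 1) / (4b),
    alpha_k = 2 pi k / (2b), gamma_l = 2 pi l / (2b).
    """
    j, k, l = index
    beta = np.pi * (2 * j + 1) / (4 * b)
    alpha = 2 * np.pi * k / (2 * b)
    gamma = 2 * np.pi * l / (2 * b)
    return alpha, beta, gamma


def rot_z(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def zyz_to_matrix(alpha, beta, gamma):
    """R = Rz(alpha) @ Ry(beta) @ Rz(gamma) (assumed ZYZ convention)."""
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)


def rotation_angle(R1, R2):
    """Geodesic distance in radians between two rotation matrices."""
    c = (np.trace(R1.T @ R2) - 1.0) / 2.0
    return float(np.arccos(np.clip(c, -1.0, 1.0)))


b = 25
R_a = zyz_to_matrix(*soft_grid_angles((0, 0, 0), b))
R_b = zyz_to_matrix(*soft_grid_angles((0, 5, 2 * b - 5), b))
# different indices, yet closer than one alpha/gamma grid step
print(rotation_angle(R_a, R_b) < 2 * np.pi / (2 * b))  # True
```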
Thanks in advance!!!