Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WaveMix Multi-GPU CUDA fix #8

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions wavemix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from einops.layers.torch import Rearrange


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
""" 1D synthesis filter bank of an image tensor
"""
Expand Down Expand Up @@ -57,6 +55,7 @@ def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):

return y


def reflect(x, minx, maxx):
"""Reflect the values in matrix *x* about the scalar values *minx* and
*maxx*. Hence a vector *x* containing a long linearly increasing series is
Expand All @@ -74,6 +73,7 @@ def reflect(x, minx, maxx):
out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
return np.array(out, dtype=x.dtype)


def mode_to_int(mode):
if mode == 'zero':
return 0
Expand All @@ -92,6 +92,7 @@ def mode_to_int(mode):
else:
raise ValueError("Unkown pad type: {}".format(mode))


def int_to_mode(mode):
if mode == 0:
return 'zero'
Expand All @@ -110,6 +111,7 @@ def int_to_mode(mode):
else:
raise ValueError("Unkown pad type: {}".format(mode))


def afb1d(x, h0, h1, mode='zero', dim=-1):
""" 1D analysis filter bank (along one dimension only) of an image
Inputs:
Expand Down Expand Up @@ -160,7 +162,7 @@ def afb1d(x, h0, h1, mode='zero', dim=-1):
N += 1
x = roll(x, -L2, dim=d)
pad = (L-1, 0) if d == 2 else (0, L-1)
lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C)
N2 = N//2
if d == 2:
lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
Expand All @@ -181,18 +183,17 @@ def afb1d(x, h0, h1, mode='zero', dim=-1):
x = F.pad(x, pad)
pad = (p//2, 0) if d == 2 else (0, p//2)
# Calculate the high and lowpass
lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C)
elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
x = mypad(x, pad=pad, mode=mode)
lohi = F.conv2d(x, h, stride=s, groups=C)
lohi = F.conv2d(x, h.to(x.device), stride=s, groups=C)
else:
raise ValueError("Unkown pad type: {}".format(mode))

return lohi



class AFB2D(Function):
""" Does a single level 2d wavelet decomposition of an input. Does separate
row and column filtering by two calls to
Expand Down Expand Up @@ -245,7 +246,7 @@ def backward(ctx, low, highs):
return dx, None, None, None, None, None


def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device='cpu'):
"""
Prepares the filters to be of the right form for the afb2d function. In
particular, makes the tensors the right shape. It takes mirror images of
Expand Down Expand Up @@ -274,7 +275,7 @@ def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
return h0_col, h1_col, h0_row, h1_row


def prep_filt_afb1d(h0, h1, device=device):
def prep_filt_afb1d(h0, h1, device='cpu'):
"""
Prepares the filters to be of the right form for the afb2d function. In
particular, makes the tensors the right shape. It takes mirror images of
Expand All @@ -293,6 +294,7 @@ def prep_filt_afb1d(h0, h1, device=device):
h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1))
return h0, h1


class DWTForward(nn.Module):
""" Performs a 2d DWT Forward decomposition of an image
Args:
Expand Down Expand Up @@ -358,12 +360,20 @@ def forward(self, x):

return ll, yh


from numpy.lib.function_base import hamming

xf1 = DWTForward(J=1, mode='zero', wave='db1').to(device)
xf2 = DWTForward(J=2, mode='zero', wave='db1').to(device)
xf3 = DWTForward(J=3, mode='zero', wave='db1').to(device)
xf4 = DWTForward(J=4, mode='zero', wave='db1').to(device)


def get_dwt_filters(level, mode='zero', wave='db1'):
    """Create the forward DWT modules for decomposition depths 1..``level``.

    One ``DWTForward`` is constructed for every depth ``j`` from 1 up to and
    including ``level`` (passed as ``J=j``); all of them share the same
    ``mode`` and ``wave``. The modules are returned exactly as constructed --
    no ``.to(device)`` call is made here.

    Args:
        level: Highest decomposition depth to build a transform for.
        mode: Padding mode forwarded to every ``DWTForward``.
        wave: Wavelet name forwarded to every ``DWTForward``.

    Returns:
        The single ``DWTForward`` instance when ``level == 1``; otherwise a
        list of ``DWTForward`` instances ordered by increasing ``J`` (the
        list is empty when ``level`` is smaller than 1).
    """
    transforms = [
        DWTForward(J=depth, mode=mode, wave=wave)
        for depth in range(1, level + 1)
    ]

    # A one-level request hands back the bare module instead of a
    # one-element list, so callers can assign it directly.
    if level == 1:
        return transforms[0]

    return transforms


class Level1Waveblock(nn.Module):
def __init__(
Expand All @@ -388,14 +398,15 @@ def __init__(
)

self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

self.xf1 = get_dwt_filters(level=1)


def forward(self, x):
b, c, h, w = x.shape

x = self.reduction(x)

Y1, Yh = xf1(x)
Y1, Yh = self.xf1(x)

x = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))

Expand All @@ -405,6 +416,7 @@ def forward(self, x):

return x


class Level2Waveblock(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -435,15 +447,16 @@ def __init__(
)

self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

self.xf1, self.xf2 = get_dwt_filters(level=2)


def forward(self, x):
b, c, h, w = x.shape

x = self.reduction(x)

Y1, Yh = xf1(x)
Y2, Yh = xf2(x)
Y1, Yh = self.xf1(x)
Y2, Yh = self.xf2(x)


x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
Expand Down Expand Up @@ -499,16 +512,17 @@ def __init__(
)

self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

self.xf1, self.xf2, self.xf3 = get_dwt_filters(level=3)


def forward(self, x):
b, c, h, w = x.shape

x = self.reduction(x)

Y1, Yh = xf1(x)
Y2, Yh = xf2(x)
Y3, Yh = xf3(x)
Y1, Yh = self.xf1(x)
Y2, Yh = self.xf2(x)
Y3, Yh = self.xf3(x)


x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
Expand Down Expand Up @@ -582,17 +596,18 @@ def __init__(
)

self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

self.xf1, self.xf2, self.xf3, self.xf4 = get_dwt_filters(level=4)


def forward(self, x):
b, c, h, w = x.shape

x = self.reduction(x)

Y1, Yh = xf1(x)
Y2, Yh = xf2(x)
Y3, Yh = xf3(x)
Y4, Yh = xf4(x)
Y1, Yh = self.xf1(x)
Y2, Yh = self.xf2(x)
Y3, Yh = self.xf3(x)
Y4, Yh = self.xf4(x)

x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
x2 = torch.reshape(Yh[1], (b, int(c*3/4), int(h/4), int(w/4)))
Expand Down
20 changes: 10 additions & 10 deletions wavemix/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from einops.layers.torch import Rearrange


class WaveMix(nn.Module):
def __init__(
self,
Expand All @@ -22,14 +23,14 @@ def __init__(

self.layers = nn.ModuleList([])
for _ in range(depth):
if level == 4:
self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
elif level == 3:
self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
elif level == 2:
self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
else:
self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
if level == 4:
self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
elif level == 3:
self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
elif level == 2:
self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
else:
self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))

self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Expand All @@ -49,8 +50,7 @@ def __init__(
nn.Conv2d(int(final_dim/2), final_dim, patch_size, patch_size),
nn.GELU(),
nn.BatchNorm2d(final_dim)
)

)

def forward(self, img):
x = self.conv(img)
Expand Down