Full JAX UNet implementation. Needs debugging #13
brianreicher committed Aug 9, 2022
1 parent 149c80e commit 66b6b19
158 changes: 148 additions & 10 deletions raygun/jax/networks/UNet.py
@@ -1,9 +1,9 @@
#%%
import math
import numpy as np
import jax
import haiku as hk


class ConvPass(hk.Module):

def __init__(
@@ -14,7 +14,7 @@ def __init__(
activation,
padding='VALID',
residual=False,
# padding_mode='reflect',
norm_layer=None,
data_format='NCDHW'):

@@ -109,7 +109,8 @@ def forward(self, x):
else:
init_x = self.x_init_map(x)
return self.activation(init_x + res)


class ConvDownsample(hk.Module):

def __init__(
@@ -120,7 +121,7 @@ def __init__(
downsample_factor,
activation,
padding='valid',
# padding_mode='reflect',
norm_layer=None,
data_format='NCDHW'):

@@ -176,7 +177,8 @@ def __init__(

def forward(self, x):
return self.conv_pass(x)


class MaxDownsample(hk.Module): # TODO: check data format type

def __init__(
@@ -213,10 +215,9 @@ def check_mismatch(self, size):
size,
self.downsample_factor,
self.dims - d))

return


class Upsample(hk.Module):

def __init__(
@@ -226,7 +227,7 @@ def __init__(
output_nc=None,
crop_factor=None,
next_conv_kernel_sizes=None,
data_format='NCDHW'):

super().__init__()

@@ -329,6 +330,7 @@ def forward(self, f_left, g_out):

return jax.lax.concatenate((f_cropped, g_cropped), dimension=1)


class UNet(hk.Module):

def __init__(self,
@@ -338,6 +340,7 @@ def __init__(self,
kernel_size_down=None,
kernel_size_up=None,
activation='relu',
input_nc=None,
output_nc=None,
num_heads=1,
constant_upsample=False,
@@ -350,4 +353,139 @@ def __init__(self,
# voxel_size=(1, 1, 1),
# num_fmaps_out=None
):

super().__init__()
self.ndims = len(downsample_factors[0])
self.num_levels = len(downsample_factors) + 1
self.num_heads = num_heads
self.input_nc = input_nc
self.output_nc = output_nc if output_nc else ngf
self.residual = residual
        # if add_noise == 'param':  # add noise feature if necessary
        #     self.noise_layer = ParameterizedNoiseBlock()
        # elif add_noise:
        #     self.noise_layer = NoiseBlock()  # TODO add utils methods
        # else:
        #     self.noise_layer = None
        self.noise_layer = None  # noise blocks disabled until ported from utils; rec_forward checks this attribute

if kernel_size_down is None:
kernel_size_down = [[(3,)*self.ndims, (3,)*self.ndims]]*self.num_levels
if kernel_size_up is None:
kernel_size_up = [[(3,)*self.ndims, (3,)*self.ndims]]*(self.num_levels - 1)
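        # i.e. by default every level applies two (3,)*ndims convolutions on
        # both the down and up paths.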

crop_factors = []
factor_product = None
for factor in downsample_factors[::-1]:
if padding_type.lower() == 'valid':
if factor_product is None:
factor_product = list(factor)
else:
factor_product = list(
f*ff
for f, ff in zip(factor, factor_product))
elif padding_type.lower() == 'same':
factor_product = None
else:
                raise ValueError(f'Invalid padding_type option: {padding_type}')
crop_factors.append(factor_product)
crop_factors = crop_factors[::-1]
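        # Worked example: downsample_factors=[(2, 2), (3, 3)] with 'valid'
        # padding gives crop_factors=[[6, 6], [3, 3]] -- each level crops its
        # skip connection to a multiple of the product of all factors below it,
        # while 'same' padding leaves every entry as None (no cropping).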

# Left pass
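        # Channel bookkeeping: with 'max' downsampling the channel count is
        # unchanged across the downsample, so level > 0 receives
        # ngf*fmap_inc_factor**(level - 1); with 'convolve', ConvDownsample has
        # already raised the count to ngf*fmap_inc_factor**level.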
self.l_conv = [ConvPass(input_nc
if level == 0
else ngf*fmap_inc_factor**(level - (downsample_method.lower() == 'max')),
ngf*fmap_inc_factor**level,
kernel_size_down[level],
activation=activation,
padding=padding_type,
residual=self.residual,
norm_layer=norm_layer)
for level in range(self.num_levels)
]
self.dims = self.l_conv[0].dims

# Left downsample
if downsample_method.lower() == 'max':
self.l_down = [MaxDownsample(downsample_factors[level])
for level in range(self.num_levels-1)]
elif downsample_method.lower() == 'convolve':
self.l_down = [ConvDownsample(ngf*fmap_inc_factor**level,
ngf*fmap_inc_factor**(level + 1),
kernel_size_down[level][0],
downsample_factors[level],
activation=activation,
padding=padding_type,
norm_layer=norm_layer)
for level in range(self.num_levels - 1)]
else:
raise RuntimeError(f'Unknown downsampling method: {downsample_method}. Please use "max" or "convolve" instead.')

# Righthand up/crop/concatenate
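        # num_heads > 1 builds independent decoder (right-side) branches that
        # all share the single encoder pass; rec_forward returns one output
        # per head.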
self.r_up = [[Upsample(downsample_factors[level],
mode='nearest' if constant_upsample else 'transposed_conv',
# input_nc=ngf*fmap_inc_factor**(level + 1) + (level==1 and (add_noise is not False)),
output_nc=ngf*fmap_inc_factor**(level + 1),
crop_factor=crop_factors[level],
next_conv_kernel_sizes=kernel_size_up[level])
for level in range(self.num_levels - 1)
            ] for _ in range(self.num_heads)]

self.r_conv = [[ConvPass(
ngf*fmap_inc_factor**level +
ngf*fmap_inc_factor**(level + 1),
ngf*fmap_inc_factor**level
if output_nc is None or level != 0
else output_nc,
kernel_size_up[level],
activation=activation,
padding=padding_type,
residual=self.residual,
norm_layer=norm_layer)
for level in range(self.num_levels - 1)]
for _ in range(self.num_heads)]

    def rec_forward(self, level, f_in):
# index of level in layer arrays
i = self.num_levels - level - 1

# convolve
        f_left = self.l_conv[i].forward(f_in)

# end of recursion
if level == 0:

if self.noise_layer is not None:
f_left = self.noise_layer(f_left)
fs_out = [f_left]*self.num_heads

else:

# down
            g_in = self.l_down[i].forward(f_left)

# nested levels
gs_out = self.rec_forward(level - 1, g_in)

# up, concat, and crop
fs_right = [
                self.r_up[h][i].forward(f_left, gs_out[h])
for h in range(self.num_heads)
]

# convolve
fs_out = [
                self.r_conv[h][i].forward(fs_right[h])
for h in range(self.num_heads)
]

return fs_out


def forward(self, x):

y = self.rec_forward(self.num_levels - 1, x)

if self.num_heads == 1:
return y[0]

return y
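
# Example usage (untested sketch -- this commit still needs debugging, and the
# keyword arguments below are inferred from how __init__ uses them, so treat
# the exact names and values as assumptions). Haiku modules must be built
# inside a transformed function:
#
#   import jax.numpy as jnp
#
#   def forward_fn(x):
#       unet = UNet(ngf=12,
#                   fmap_inc_factor=5,
#                   downsample_factors=[(2, 2, 2), (2, 2, 2)],
#                   input_nc=1)
#       return unet.forward(x)
#
#   model = hk.transform(forward_fn)
#   rng = jax.random.PRNGKey(42)
#   x = jnp.ones((1, 1, 64, 64, 64))  # NCDHW; with 'valid' padding the
#   # spatial size must be chosen to satisfy the crop/downsample constraints
#   params = model.init(rng, x)
#   y = model.apply(params, rng, x)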
