# Patch for raygun/jax/networks/UNet.py (git unified diff). In this copy the newlines were collapsed to spaces, so hunk line boundaries are implied by the leading +/- markers.
# What the patch does:
#   * removes the `#%%` cell markers and comments out the `padding_mode` parameter of ConvPass and ConvDownsample;
#   * changes the default of Upsample's `data_format` from None to 'NCDHW';
#   * adds an `input_nc=None` parameter to UNet.__init__ and replaces its `pass` body with the construction of the left conv passes (l_conv), the downsamplers (l_down), and the per-head upsamplers (r_up) and right conv passes (r_conv);
#   * adds UNet.rec_forward, which recurses from level num_levels - 1 down to level 0 and returns a list with one output per head, and UNet.forward, which returns that list, or its only element when num_heads == 1.
# Review fixes folded into the patch (none of them changes a hunk's line counts):
#   * Upsample now derives from hk.Module; the context line read hk.module, while every other class in this patch derives from hk.Module.
#   * An unrecognised padding_type now raises ValueError; the previous `raise f'...'` raised a str, which Python 3 rejects with TypeError.
#   * self.noise_layer is initialised to None in UNet.__init__, because rec_forward reads it and its only assignments were commented out.
# NOTE(review): layers are invoked as instance(...) while the classes visible here define forward(); confirm a __call__ wrapper exists outside the visible hunks.
diff --git a/raygun/jax/networks/UNet.py b/raygun/jax/networks/UNet.py index 605c11ea..31bfcb16 100644 --- a/raygun/jax/networks/UNet.py +++ b/raygun/jax/networks/UNet.py @@ -1,9 +1,9 @@ -#%% import math import numpy as np import jax import haiku as hk -#%% + + class ConvPass(hk.Module): def __init__( @@ -14,7 +14,7 @@ def __init__( activation, padding='VALID', residual=False, - padding_mode='reflect', + # padding_mode='reflect', norm_layer=None, data_format='NCDHW'): @@ -109,7 +109,8 @@ def forward(self, x): else: init_x = self.x_init_map(x) return self.activation(init_x + res) -#%% + + class ConvDownsample(hk.Module): def __init__( @@ -120,7 +121,7 @@ def __init__( downsample_factor, activation, padding='valid', - padding_mode='reflect', + # padding_mode='reflect', norm_layer=None, data_format='NCDHW'): @@ -176,7 +177,8 @@ def __init__( def forward(self, x): return self.conv_pass(x) -#%% + + class MaxDownsample(hk.Module): # TODO: check data format type def __init__( @@ -213,10 +215,9 @@ def check_mismatch(self, size): size, self.downsample_factor, self.dims - d)) - return -#%% + -class Upsample(hk.module): +class Upsample(hk.Module): def __init__( @@ -226,7 +227,7 @@ def __init__( output_nc=None, crop_factor=None, next_conv_kernel_sizes=None, - data_format=None): + data_format='NCDHW'): super().__init__() @@ -329,6 +330,7 @@ def forward(self, f_left, g_out): return jax.lax.concatenate((f_cropped, g_cropped), dimension=1) + class UNet(hk.Module): def __init__(self, @@ -338,6 +340,7 @@ def __init__(self, kernel_size_down=None, kernel_size_up=None, activation='relu', + input_nc=None, output_nc=None, num_heads=1, constant_upsample=False, @@ -350,4 +353,139 @@ def __init__(self, # voxel_size=(1, 1, 1), # num_fmaps_out=None ): - pass \ No newline at end of file + + super().__init__() + self.ndims = len(downsample_factors[0]) + self.num_levels = len(downsample_factors) + 1 + self.num_heads = num_heads + self.input_nc = input_nc + self.output_nc = output_nc if output_nc else ngf + self.residual = 
residual + # if add_noise == 'param': # add noise feature if necessary + # self.noise_layer = ParameterizedNoiseBlock() + # elif add_noise: + # self.noise_layer = NoiseBlock() # TODO add utils methods + # else: + # self.noise_layer = None + self.noise_layer = None # noise blocks are not ported yet; rec_forward reads this attribute + if kernel_size_down is None: + kernel_size_down = [[(3,)*self.ndims, (3,)*self.ndims]]*self.num_levels + if kernel_size_up is None: + kernel_size_up = [[(3,)*self.ndims, (3,)*self.ndims]]*(self.num_levels - 1) + + crop_factors = [] + factor_product = None + for factor in downsample_factors[::-1]: + if padding_type.lower() == 'valid': + if factor_product is None: + factor_product = list(factor) + else: + factor_product = list( + f*ff + for f, ff in zip(factor, factor_product)) + elif padding_type.lower() == 'same': + factor_product = None + else: + raise ValueError(f'Invalid padding_type option: {padding_type}') + crop_factors.append(factor_product) + crop_factors = crop_factors[::-1] + + # Left pass + self.l_conv = [ConvPass(input_nc + if level == 0 + else ngf*fmap_inc_factor**(level - (downsample_method.lower() == 'max')), + ngf*fmap_inc_factor**level, + kernel_size_down[level], + activation=activation, + padding=padding_type, + residual=self.residual, + norm_layer=norm_layer) + for level in range(self.num_levels) + ] + self.dims = self.l_conv[0].dims + + # Left downsample + if downsample_method.lower() == 'max': + self.l_down = [MaxDownsample(downsample_factors[level]) + for level in range(self.num_levels-1)] + elif downsample_method.lower() == 'convolve': + self.l_down = [ConvDownsample(ngf*fmap_inc_factor**level, + ngf*fmap_inc_factor**(level + 1), + kernel_size_down[level][0], + downsample_factors[level], + activation=activation, + padding=padding_type, + norm_layer=norm_layer) + for level in range(self.num_levels - 1)] + else: + raise RuntimeError(f'Unknown downsampling method: {downsample_method}. 
Please use "max" or "convolve" instead.') + + # Righthand up/crop/concatenate + self.r_up = [[Upsample(downsample_factors[level], + mode='nearest' if constant_upsample else 'transposed_conv', + # input_nc=ngf*fmap_inc_factor**(level + 1) + (level==1 and (add_noise is not False)), + output_nc=ngf*fmap_inc_factor**(level + 1), + crop_factor=crop_factors[level], + next_conv_kernel_sizes=kernel_size_up[level]) + for level in range(self.num_levels - 1) + ]for _ in range(num_heads)] + + self.r_conv = [[ConvPass( + ngf*fmap_inc_factor**level + + ngf*fmap_inc_factor**(level + 1), + ngf*fmap_inc_factor**level + if output_nc is None or level != 0 + else output_nc, + kernel_size_up[level], + activation=activation, + padding=padding_type, + residual=self.residual, + norm_layer=norm_layer) + for level in range(self.num_levels - 1)] + for _ in range(self.num_heads)] + + def rec_forward(self, level,f_in): + # index of level in layer arrays + i = self.num_levels - level - 1 + + # convolve + f_left = self.l_conv[i](f_in) + + # end of recursion + if level == 0: + + if self.noise_layer is not None: + f_left = self.noise_layer(f_left) + fs_out = [f_left]*self.num_heads + + else: + + # down + g_in = self.l_down[i](f_left) + + # nested levels + gs_out = self.rec_forward(level - 1, g_in) + + # up, concat, and crop + fs_right = [ + self.r_up[h][i](f_left, gs_out[h]) + for h in range(self.num_heads) + ] + + # convolve + fs_out = [ + self.r_conv[h][i](fs_right[h]) + for h in range(self.num_heads) + ] + + return fs_out + + + def forward(self, x): + + y = self.rec_forward(self.num_levels - 1, x) + + if self.num_heads == 1: + return y[0] + + return y \ No newline at end of file