From b3b8ca41cad4be6cacfccf9c262e9d6c1d23fe93 Mon Sep 17 00:00:00 2001 From: brianreicher Date: Tue, 9 Aug 2022 10:53:22 -0400 Subject: [PATCH] Variable renaming for UNet #13 --- raygun/jax/networks/UNet.py | 39 +++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/raygun/jax/networks/UNet.py b/raygun/jax/networks/UNet.py index a99ba2bc..605c11ea 100644 --- a/raygun/jax/networks/UNet.py +++ b/raygun/jax/networks/UNet.py @@ -1,5 +1,6 @@ #%% import math +import numpy as np import jax import haiku as hk #%% @@ -8,7 +9,7 @@ class ConvPass(hk.Module): def __init__( self, input_nc, - ouput_nc, + output_nc, kernel_sizes, activation, padding='VALID', @@ -53,7 +54,7 @@ def __init__( try: layers.append( conv( - output_channels=ouput_nc, + output_channels=output_nc, kernel_shape=kernel_size, padding=padding, # padding_mode=padding_mode, @@ -75,7 +76,7 @@ def __init__( except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) - if norm_layers is not None: + if norm_layer is not None: layers.append(norm_layer(output_nc)) if not (residual and i == (len(kernel_sizes)-1)): @@ -112,15 +113,16 @@ def forward(self, x): class ConvDownsample(hk.Module): def __init__( - self, - input_nc, - output_nc, - kernel_sizes, - downsample_factor, - activation, - padding='valid', - padding_mode='reflect', - norm_layer=None): + self, + input_nc, + output_nc, + kernel_sizes, + downsample_factor, + activation, + padding='valid', + padding_mode='reflect', + norm_layer=None, + data_format='NCDHW'): super().__init__() @@ -156,8 +158,8 @@ def __init__( try: layers.append( conv( - output_channels=ouput_nc, - kernel_shape=kernel_size, + output_channels=output_nc, + kernel_shape=kernel_sizes, stride=downsample_factor, padding=padding, # padding_mode=padding_mode, @@ -166,7 +168,7 @@ def __init__( except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) - if norm_layers is not None: + if norm_layer is not None: layers.append(norm_layer(output_nc)) layers.append(self.activation) @@ -184,7 +186,7 @@ def __init__( super().__init__() - self.dims = len(dowmsample_factor) + self.dims = len(downsample_factor) self.downsample_factor = downsample_factor self.flexible = flexible @@ -211,7 +213,9 @@ def check_mismatch(self, size): size, self.downsample_factor, self.dims - d)) + return + #%% class Upsample(hk.module): @@ -345,4 +349,5 @@ def __init__(self, # fov=(1, 1, 1), # voxel_size=(1, 1, 1), # num_fmaps_out=None - ): \ No newline at end of file + ): + pass \ No newline at end of file