Skip to content

Commit

Permalink
Variable renaming for UNet #13
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 9, 2022
1 parent 8b024df commit b3b8ca4
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions raygun/jax/networks/UNet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#%%
import math
import numpy as np
import jax
import haiku as hk
#%%
Expand All @@ -8,7 +9,7 @@ class ConvPass(hk.Module):
def __init__(
self,
input_nc,
ouput_nc,
output_nc,
kernel_sizes,
activation,
padding='VALID',
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
try:
layers.append(
conv(
output_channels=ouput_nc,
output_channels=output_nc,
kernel_shape=kernel_size,
padding=padding,
# padding_mode=padding_mode,
Expand All @@ -75,7 +76,7 @@ def __init__(
except KeyError:
raise RuntimeError("%dD convolution not implemented" % self.dims)

if norm_layers is not None:
if norm_layer is not None:
layers.append(norm_layer(output_nc))

if not (residual and i == (len(kernel_sizes)-1)):
Expand Down Expand Up @@ -112,15 +113,16 @@ def forward(self, x):
class ConvDownsample(hk.Module):

def __init__(
self,
input_nc,
output_nc,
kernel_sizes,
downsample_factor,
activation,
padding='valid',
padding_mode='reflect',
norm_layer=None):
self,
input_nc,
output_nc,
kernel_sizes,
downsample_factor,
activation,
padding='valid',
padding_mode='reflect',
norm_layer=None,
data_format='NCDHW'):

super().__init__()

Expand Down Expand Up @@ -156,8 +158,8 @@ def __init__(
try:
layers.append(
conv(
output_channels=ouput_nc,
kernel_shape=kernel_size,
output_channels=output_nc,
kernel_shape=kernel_sizes,
stride=downsample_factor,
padding=padding,
# padding_mode=padding_mode,
Expand All @@ -166,7 +168,7 @@ def __init__(
except KeyError:
raise RuntimeError("%dD convolution not implemented" % self.dims)

if norm_layers is not None:
if norm_layer is not None:
layers.append(norm_layer(output_nc))

layers.append(self.activation)
Expand All @@ -184,7 +186,7 @@ def __init__(

super().__init__()

self.dims = len(dowmsample_factor)
self.dims = len(downsample_factor)
self.downsample_factor = downsample_factor
self.flexible = flexible

Expand All @@ -211,7 +213,9 @@ def check_mismatch(self, size):
size,
self.downsample_factor,
self.dims - d))

return

#%%
class Upsample(hk.module):

Expand Down Expand Up @@ -345,4 +349,5 @@ def __init__(self,
# fov=(1, 1, 1),
# voxel_size=(1, 1, 1),
# num_fmaps_out=None
):
):
pass

0 comments on commit b3b8ca4

Please sign in to comment.