forked from d1ggs/cycleGAN-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
25 lines (19 loc) · 897 Bytes
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
from keras.layers import Layer, InputSpec
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import ZeroPadding2D
import numpy as np
class ReflectionPadding2D(Layer):
    """Reflection-pad the two trailing spatial axes of a 4-D tensor.

    ``call`` applies ``tf.pad(..., 'REFLECT')`` to axes 2 and 3 and leaves
    axes 0 and 1 unpadded, i.e. it treats the input as
    ``(batch, channels, height, width)``.
    NOTE(review): that is channels-first layout; confirm it matches the
    image data format used by the surrounding model.

    Args:
        padding: ``(w_pad, h_pad)``. ``h_pad`` rows are reflected on each
            side of axis 2 and ``w_pad`` columns on each side of axis 3.
            A single int is used for both. Defaults to ``(1, 1)``.
        **kwargs: Forwarded unchanged to the Keras ``Layer`` constructor.

    Raises:
        ValueError: If ``padding`` does not have exactly two entries.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer first: Keras' Layer.__init__ resets
        # ``input_spec``, so assigning it beforehand would be discarded
        # (and newer Keras versions reject attribute assignment before it).
        super(ReflectionPadding2D, self).__init__(**kwargs)
        if isinstance(padding, int):
            padding = (padding, padding)
        self.padding = tuple(padding)
        if len(self.padding) != 2:
            raise ValueError(
                'padding must be an int or a pair (w_pad, h_pad), got %r'
                % (padding,))
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, input_shape):
        """Return the output shape produced by ``call`` for ``input_shape``.

        Axis 2 grows by ``2 * h_pad`` and axis 3 by ``2 * w_pad``, matching
        the ``tf.pad`` paddings used in ``call``. Unknown (``None``)
        dimensions stay ``None``.
        """
        w_pad, h_pad = self.padding
        height = input_shape[2]
        width = input_shape[3]
        return (
            input_shape[0],
            input_shape[1],
            None if height is None else height + 2 * h_pad,
            None if width is None else width + 2 * w_pad,
        )

    def call(self, x, mask=None):
        """Reflection-pad ``x`` on axes 2 (by ``h_pad``) and 3 (by ``w_pad``).

        ``mask`` is accepted for API compatibility and ignored.
        """
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [0, 0], [h_pad, h_pad], [w_pad, w_pad]],
                      'REFLECT')

    def get_config(self):
        """Return the layer config, including ``padding``, for serialization."""
        config = super(ReflectionPadding2D, self).get_config()
        config['padding'] = self.padding
        return config