"""
Shape manipulators of arrays within the framework.
CITATION: These functions are based(but not directly taken from) on reshape implementations "https://github.com/bgavran/autodiff"
"""
#Import required packages
import numbers
import numpy as np
from .node import Node

class Pad(Node):
    """
    Operation that pads an array with a constant value on either side of each axis
    """
    def __init__(self, node, pad_width, constant_values, name="Pad"):
        """
        :param node: the array to be padded
        :param pad_width: sequence of (before, after) pad amounts, one pair per
            axis, passed directly to np.pad
        :param constant_values: the constant used to fill the padded region
        :param name: name of the operation
        """
        super().__init__([node], name)
        self.node = self.children[0]
        self.constant_values = constant_values
        self.pad_width = pad_width
        # Pad a dummy array of ones to determine the output shape
        self.shape = np.pad(np.ones(self.node.shape),
                            self.pad_width,
                            mode="constant",
                            constant_values=self.constant_values).shape

    def _eval(self):
        val = self.node()
        return np.pad(val,
                      self.pad_width,
                      mode="constant",
                      constant_values=self.constant_values)

    def _partial_derivative(self, wrt, previous_grad):
        if self.node == wrt:
            # The gradient of padding is the inverse slice: keep only the part
            # of the incoming gradient that maps back onto the unpadded input
            slice_val = tuple(slice(pad[0], shp - pad[1])
                              for pad, shp in zip(self.pad_width, self.shape))
            return previous_grad[slice_val]
        return 0
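
# A minimal numpy sketch (not part of the original module, illustrative only) of
# the pad/un-pad pair used by Pad: the gradient of a constant pad is just the
# slice that removes the padding again.
def _pad_gradient_sketch():
    x = np.arange(6).reshape(2, 3)
    pad_width = ((1, 1), (0, 2))                     # (before, after) per axis
    padded = np.pad(x, pad_width, mode="constant", constant_values=0)
    # Invert the padding with the same slices Pad._partial_derivative builds
    slices = tuple(slice(before, dim - after)
                   for (before, after), dim in zip(pad_width, padded.shape))
    assert np.array_equal(padded[slices], x)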

class Concat(Node):
    """
    Operation to concatenate two arrays along a given axis
    """
    def __init__(self, a, b, axis=0):
        """
        params:
        a, b: input arrays to be concatenated
        axis: axis along which the concatenation takes place
        """
        # Sanity check: negative axes are not supported
        assert axis >= 0
        super().__init__([a, b], name="Concat")
        self.a, self.b = self.children
        self.axis = axis
        # The output shape is a's shape with the concatenation axis enlarged by b
        self.shape = list(self.a.shape)
        self.shape[self.axis] += self.b.shape[self.axis]

    def _eval(self):
        a_val = self.a()
        b_val = self.b()
        return np.concatenate((a_val, b_val), axis=self.axis)

    def _partial_derivative(self, wrt, previous_grad):
        previous_grad = Reshape(previous_grad, self.shape)
        # Split the incoming gradient at the boundary between a and b
        split = self.a.shape[self.axis]
        slice_val = [slice(None, None, None) for _ in range(self.axis + 1)]
        if wrt == self.a:
            slice_val[self.axis] = slice(None, split, None)
            return previous_grad[tuple(slice_val)]
        elif wrt == self.b:
            slice_val[self.axis] = slice(split, None, None)
            return previous_grad[tuple(slice_val)]
        return 0
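
# A minimal numpy sketch (not part of the original module, illustrative only) of
# how Concat splits an incoming gradient: each input receives the chunk of the
# gradient that its values occupy in the concatenated output.
def _concat_gradient_sketch():
    a = np.ones((2, 3))
    b = np.ones((2, 2))
    axis = 1
    grad = np.arange(10.0).reshape(2, 5)             # gradient w.r.t. the output
    split = a.shape[axis]
    grad_a = grad[:, :split]                         # first a.shape[axis] columns
    grad_b = grad[:, split:]                         # remaining columns go to b
    assert grad_a.shape == a.shape and grad_b.shape == b.shape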

class ReduceSumKeepDims(Node):
    """
    Operation which sums an array over the given axes while keeping its rank:
    every reduced axis remains in the output with size 1
    """
    def __init__(self, node, axes):
        """
        params:
        node: the array to be reduced
        axes: axes to sum over; they are kept in the output with size 1
        """
        super().__init__([node], name="ReduceSumKeepDims")
        self.axes = tuple(axes)
        self.node = self.children[0]
        self.shape = [1 if i in self.axes else shp for i, shp in enumerate(self.node.shape)]

    def _eval(self):
        return np.sum(self.node(), axis=self.axes, keepdims=True)

    def _partial_derivative(self, wrt, previous_grad):
        if self.node == wrt:
            # Each input element contributes exactly once to its sum, so the
            # gradient is the incoming gradient broadcast back to the input shape
            return previous_grad * np.ones(self.node.shape)
        return 0
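
# A minimal numpy sketch (not part of the original module, illustrative only) of
# the ReduceSumKeepDims gradient: because keepdims=True leaves size-1 axes in
# place, multiplying by ones of the input shape broadcasts the gradient back
# without any explicit reshaping.
def _reduce_sum_keepdims_sketch():
    x = np.arange(12.0).reshape(3, 4)
    y = np.sum(x, axis=(1,), keepdims=True)          # shape (3, 1)
    upstream = np.ones_like(y)                       # gradient w.r.t. y
    grad_x = upstream * np.ones(x.shape)             # broadcasts to (3, 4)
    assert grad_x.shape == x.shape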

class Slice(Node):
    """
    Operation to slice an array down to the required shape
    """
    def __init__(self, node, slice_val, name="Slice"):
        """
        params:
        node: the array to be sliced
        slice_val: index expression (e.g. a tuple of slice objects) applied to the array
        """
        super().__init__([node], name)
        self.node = self.children[0]
        self.slice_val = slice_val
        # Slice a dummy array of zeros to determine the output shape
        self.shape = np.zeros(self.node.shape)[self.slice_val].shape

    def _eval(self):
        val = self.node()
        return val[self.slice_val]

    def _partial_derivative(self, wrt, previous_grad):
        if self.node == wrt:
            # Scatter the (eagerly evaluated) incoming gradient back into the
            # sliced positions; elements outside the slice get zero gradient
            grad = np.zeros(wrt.shape)
            grad[self.slice_val] = previous_grad()
            return grad
        return 0
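
# A minimal numpy sketch (not part of the original module, illustrative only) of
# the Slice gradient: the upstream gradient is scattered into a zero array of
# the input's shape, so elements sliced away correctly receive zero gradient.
def _slice_gradient_sketch():
    x = np.arange(12.0).reshape(3, 4)
    slice_val = (slice(0, 2), slice(1, 3))
    y = x[slice_val]                                 # shape (2, 2)
    upstream = np.ones_like(y)
    grad = np.zeros(x.shape)
    grad[slice_val] = upstream                       # scatter back into place
    assert grad.sum() == upstream.size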

class Reshape(Node):
    """
    Operation to reshape a given array
    """
    def __init__(self, node, shape, name="Reshape"):
        """
        params:
        node: the array to be reshaped
        shape: the target shape; a single -1 entry is inferred from the input size
        """
        super().__init__([node], name)
        self.node = self.children[0]
        self.shape = self.infer_shape(shape)

    def infer_shape(self, shape):
        """
        Resolve a single -1 entry in the target shape, numpy-style
        params:
        shape: the requested output shape
        returns the concrete output shape
        """
        if isinstance(shape, numbers.Number):
            return shape
        if -1 in shape:
            shape = list(shape)
            for i in range(len(shape)):
                if shape[i] == -1:
                    # np.prod(shape) still contains the -1 factor, so negating
                    # the ratio of element counts yields the missing dimension
                    shape[i] = int(-np.prod(self.node.shape) / np.prod(shape))
        return shape

    def _eval(self):
        node_val = self.node()
        if isinstance(node_val, numbers.Number):
            return np.broadcast_to(node_val, self.shape)
        return np.reshape(node_val, self.shape)

    def _partial_derivative(self, wrt, previous_grad):
        if self.node == wrt:
            # Reshaping moves no values between elements, so the gradient is just
            # the incoming gradient reshaped back to the input's shape
            return Reshape(previous_grad, self.node.shape)
        return 0
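
# A minimal numpy sketch (not part of the original module, illustrative only) of
# the -1 inference used in Reshape.infer_shape: np.prod of a shape that still
# contains -1 is negative, so negating the ratio of element counts recovers the
# missing dimension.
def _infer_shape_sketch():
    input_shape = (2, 3, 4)                          # 24 elements in total
    target = [4, -1]
    missing = int(-np.prod(input_shape) / np.prod(target))
    assert missing == 6                              # 24 / 4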