"""
This is the definition of base class for all the operations in this autodiff package
Every variable and operation is an extension of "Node" which is a point in a computational graph.
This node and first extension of node "variable" as starters of a compuational graph are defined here.
CITATION: These two classes are directly taken from "https://github.com/bgavran/autodiff"
"""
# Import required packages
import time
import numbers
import numpy as np
from contextlib import contextmanager


class Node:
    """
    Node is the blueprint for every operation and variable defined in the autodiff package.
    Every Node does two things:
    1. Calculates its value
    2. Passes along information about the gradient
    It therefore acts as a wrapper around NumPy, using it as a numerical kernel.
    """
    # Some attributes common to all Nodes
    epsilon = 1e-12
    id = 0
    context_list = []

    def __init__(self, children, name="Node"):
        """
        params:
        children: the inputs that enter a node to be operated on (e.g. two numbers which need to be added)
        id: tracks the number of nodes instantiated, used to sort nodes while collecting the derivatives
        cached: stores the evaluated value so it is computed only once
        shape: shape of the node's value when dealing with numpy arrays
        """
        # Wrap plain numbers into Variables
        self.children = [child if isinstance(child, Node) else Variable(child) for child in children]
        self.name = name
        self.cached = None
        self.shape = None
        self.context_list = Node.context_list.copy()
        self.id = Node.id
        Node.id += 1
    def _eval(self):
        """
        Left unimplemented, with a leading underscore in the name (virtual method), so that:
        1. It is a private method (not accessed directly through the object)
        2. It is overridden by every child class according to its operation (the Add class adds two numbers, and so on)
        :return: the value of the evaluated Node
        """
        raise NotImplementedError()

    def _partial_derivative(self, wrt, previous_grad):
        """
        Calculates the partial derivative of self with respect to the wrt Node.
        Because this method is defined without evaluating any nodes, higher-order
        gradients are available for free.
        Left unimplemented, with a leading underscore in the name (virtual method), so that:
        1. It is a private method (not accessed directly through the object)
        2. It is overridden by every child class according to its operation (the derivative of x*y wrt x is y, and so on)
        :param wrt: instance of Node, the partial derivative is taken with respect to it
        :param previous_grad: gradient with respect to self
        :return: an instance of Node whose evaluation yields the partial derivative
        """
        raise NotImplementedError()

    def eval(self):
        """
        Public method for accessing the private _eval method.
        Only works once the private method has been appropriately overridden.
        """
        # Compute the value only once and cache it
        if self.cached is None:
            self.cached = self._eval()
        return self.cached

    def partial_derivative(self, wrt, previous_grad):
        """
        Public method for accessing the private _partial_derivative method.
        Only works once the private method has been appropriately overridden.
        """
        with add_context(self.name + " PD" + " wrt " + str(wrt)):
            return self._partial_derivative(wrt, previous_grad)
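
    # -----------------------------------------------------------------------
    # Caching contract in brief (illustrative sketch; SomeOp stands for any
    # hypothetical concrete Node subclass that implements _eval):
    #
    #   n = SomeOp(x, y)
    #   n.eval()       # computed via _eval() and stored in n.cached
    #   n.eval()       # second call returns n.cached without recomputing
    # -----------------------------------------------------------------------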
    # Magic methods that give the code a natural look (e.g. x + y reads better than Add(x, y))
    def __call__(self, *args, **kwargs):
        return self.eval()

    def __str__(self):
        return self.name  # + " " + str(self.id)

    def __add__(self, other):
        from .ops import Add
        return Add(self, other)

    def __neg__(self):
        from .ops import Negate
        return Negate(self)

    def __sub__(self, other):
        return self.__add__(other.__neg__())

    def __rsub__(self, other):
        return self.__neg__().__add__(other)

    def __mul__(self, other):
        from .ops import Mul
        return Mul(self, other)

    def __matmul__(self, other):
        from .high_level_ops import MatMul
        return MatMul(self, other)

    def __rmatmul__(self, other):
        from .high_level_ops import MatMul
        return MatMul(other, self)

    def __imatmul__(self, other):
        return self.__matmul__(other)

    def __truediv__(self, other):
        from .ops import Recipr
        return self.__mul__(Recipr(other))

    def __rtruediv__(self, other):
        from .ops import Recipr
        return Recipr(self).__mul__(other)

    def __pow__(self, power, modulo=None):
        from .ops import Pow
        return Pow(self, power)

    __rmul__ = __mul__
    __radd__ = __add__

    def __getitem__(self, item):
        from .reshape import Slice
        return Slice(self, item)
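
# ---------------------------------------------------------------------------
# The operator overloads above let graph construction read like plain math.
# A sketch, assuming the companion ops module is importable:
#
#   w, b = Variable(2.0, name="w"), Variable(1.0, name="b")
#   f = w * 3 + b      # the plain number 3 is auto-wrapped into a Variable
#   f.eval()           # -> 7.0; the graph is only evaluated on demand
# ---------------------------------------------------------------------------
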
class Variable(Node):
    """
    The starting point of every computational graph, because every operation begins with a set of variables.
    Derivatives end here: the last calculated gradient is the total gradient of the computational graph.
    """

    def __init__(self, value, name=None):
        """
        If no name is given for the variable, it is named after the string of its value.
        The rest is initialized the same as the superclass, Node.
        params:
        value: the value to be stored in the variable,
               kept as a private attribute and accessed only through the equivalent public property to preserve sanctity.
        """
        if name is None:
            name = str(value)  # this op is really slow for np.arrays?!
        super().__init__([], name)
        if isinstance(value, numbers.Number):
            self._value = np.array(value, dtype=np.float64)
        else:
            self._value = value
        self.shape = self._value.shape
    # Property that exposes the value stored in the private attribute
    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, val):
        # Keep the cache in sync with the new value
        self.cached = self._value = val

    def _eval(self):
        """
        The overridden implementation
        """
        return self._value

    def _partial_derivative(self, wrt, previous_grad):
        """
        The overridden implementation
        """
        if self == wrt:
            return previous_grad
        return 0
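
# ---------------------------------------------------------------------------
# Variable semantics in brief (a sketch):
#
#   v = Variable(3.0)                                     # name defaults to "3.0"
#   v.eval()                                              # -> array(3.)
#   v.partial_derivative(v, Variable(1.0))                # dv/dv: returns previous_grad
#   v.partial_derivative(Variable(5.0), Variable(1.0))    # unrelated node: returns 0
# ---------------------------------------------------------------------------
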
@contextmanager
def add_context(ctx):
    Node.context_list.append(ctx + "_" + str(time.time()))
    try:
        yield
    finally:
        del Node.context_list[-1]
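
# ---------------------------------------------------------------------------
# A minimal usage sketch, assuming the companion ops module (providing Add,
# Mul, etc., as imported by the dunder methods above) is present, and that
# this module is run from within its package so the relative imports resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = Variable(2.0, name="x")
    y = Variable(3.0, name="y")
    z = x * y + x      # lazily builds Add(Mul(x, y), x); nothing is computed yet
    print(z())         # __call__ -> eval(); expected output: 8.0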