1
+ import numpy as np
2
+
1
3
#Class to hold weight or bias arrays for layers
2
4
class Weight :
3
5
def __init__ (self , id , nparray = None ):
@@ -14,36 +16,57 @@ def getParams(self):
14
16
if self .values .any () == None :
15
17
return ''
16
18
17
- shape = self .values .shape
18
-
19
- if len (shape ) == 1 :
20
- param_string = 'float ' + self .identifier + \
21
- '[' + str (shape [0 ]) + '] = {'
22
-
23
- param_string = param_string + '{'
24
- for i in range (shape [0 ]):
25
- param_string = param_string + str (self .values [i ])
19
+ flat = self .values .flatten ()
20
+ size = flat .shape [0 ]
26
21
27
- if i < shape [ 0 ] - 1 :
28
- param_string = param_string + ', '
22
+ param_string = 'float ' + self . identifier + \
23
+ '[' + str ( size ) + '] = { '
29
24
30
- param_string = param_string + '}\n '
25
+ for i in range (size ):
26
+ param_string = param_string + str (flat [i ])
31
27
32
- else :
33
- param_string = 'float ' + self .identifier + \
34
- '[' + str (shape [0 ]) + '][' + str (shape [1 ]) + '] = {'
28
+ if i < size - 1 :
29
+ param_string = param_string + ', '
35
30
36
- for i in range (shape [0 ]):
37
- param_string = param_string + '{'
38
- for j in range (shape [1 ]):
39
- param_string = param_string + str (self .values [i ][j ])
31
+ param_string = param_string + '}\n \n '
40
32
41
- if j < shape [1 ]- 1 :
42
- param_string = param_string + ', '
43
-
44
- if i < shape [0 ]- 1 :
45
- param_string = param_string + '},\n \t \t \t \t \t \t \t \t \t '
33
+ return param_string
46
34
47
- param_string = param_string + '}\n '
48
35
49
- return param_string
36
+ #Deprecated
37
+ # def getParams(self):
38
+ # if self.values.any() == None:
39
+ # return ''
40
+ #
41
+ # shape = self.values.shape
42
+ #
43
+ # if len(shape) == 1:
44
+ # param_string = 'float ' + self.identifier +\
45
+ # '[' + str(shape[0]) + '] = {'
46
+ #
47
+ # for i in range(shape[0]):
48
+ # param_string = param_string + str(self.values[i])
49
+ #
50
+ # if i<shape[0]-1:
51
+ # param_string = param_string + ', '
52
+ #
53
+ # param_string = param_string + '}\n'
54
+ #
55
+ # else:
56
+ # param_string = 'float ' + self.identifier +\
57
+ # '[' + str(shape[0]) + '][' + str(shape[1]) + '] = {'
58
+ #
59
+ # for i in range(shape[0]):
60
+ # param_string = param_string + '{'
61
+ # for j in range(shape[1]):
62
+ # param_string = param_string + str(self.values[i][j])
63
+ #
64
+ # if j<shape[1]-1:
65
+ # param_string = param_string + ', '
66
+ #
67
+ # if i<shape[0]-1:
68
+ # param_string = param_string + '},\n\t\t\t\t\t\t\t\t\t'
69
+ #
70
+ # param_string = param_string + '}\n'
71
+ #
72
+ # return param_string
0 commit comments