Skip to content

Commit 30b7e2f

Browse files
committed
Will I ever actually be ready to publish, hopefully now I am. Weight parameter stuff is fixed.
1 parent 65e668e commit 30b7e2f

File tree

1 file changed

+49
-26
lines changed

1 file changed

+49
-26
lines changed
+49-26
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
#Class to hold weight or bias arrays for layers
24
class Weight:
35
def __init__(self, id, nparray=None):
@@ -14,36 +16,57 @@ def getParams(self):
1416
if self.values.any() == None:
1517
return ''
1618

17-
shape = self.values.shape
18-
19-
if len(shape) == 1:
20-
param_string = 'float ' + self.identifier +\
21-
'[' + str(shape[0]) + '] = {'
22-
23-
param_string = param_string + '{'
24-
for i in range(shape[0]):
25-
param_string = param_string + str(self.values[i])
19+
flat = self.values.flatten()
20+
size = flat.shape[0]
2621

27-
if i<shape[0]-1:
28-
param_string = param_string + ', '
22+
param_string = 'float ' + self.identifier +\
23+
'[' + str(size) + '] = {'
2924

30-
param_string = param_string + '}\n'
25+
for i in range(size):
26+
param_string = param_string + str(flat[i])
3127

32-
else:
33-
param_string = 'float ' + self.identifier +\
34-
'[' + str(shape[0]) + '][' + str(shape[1]) + '] = {'
28+
if i<size-1:
29+
param_string = param_string + ', '
3530

36-
for i in range(shape[0]):
37-
param_string = param_string + '{'
38-
for j in range(shape[1]):
39-
param_string = param_string + str(self.values[i][j])
31+
param_string = param_string + '}\n\n'
4032

41-
if j<shape[1]-1:
42-
param_string = param_string + ', '
43-
44-
if i<shape[0]-1:
45-
param_string = param_string + '},\n\t\t\t\t\t\t\t\t\t'
33+
return param_string
4634

47-
param_string = param_string + '}\n'
4835

49-
return param_string
36+
#Deprecated
37+
# def getParams(self):
38+
# if self.values.any() == None:
39+
# return ''
40+
#
41+
# shape = self.values.shape
42+
#
43+
# if len(shape) == 1:
44+
# param_string = 'float ' + self.identifier +\
45+
# '[' + str(shape[0]) + '] = {'
46+
#
47+
# for i in range(shape[0]):
48+
# param_string = param_string + str(self.values[i])
49+
#
50+
# if i<shape[0]-1:
51+
# param_string = param_string + ', '
52+
#
53+
# param_string = param_string + '}\n'
54+
#
55+
# else:
56+
# param_string = 'float ' + self.identifier +\
57+
# '[' + str(shape[0]) + '][' + str(shape[1]) + '] = {'
58+
#
59+
# for i in range(shape[0]):
60+
# param_string = param_string + '{'
61+
# for j in range(shape[1]):
62+
# param_string = param_string + str(self.values[i][j])
63+
#
64+
# if j<shape[1]-1:
65+
# param_string = param_string + ', '
66+
#
67+
# if i<shape[0]-1:
68+
# param_string = param_string + '},\n\t\t\t\t\t\t\t\t\t'
69+
#
70+
# param_string = param_string + '}\n'
71+
#
72+
# return param_string

0 commit comments

Comments
 (0)