@@ -50,20 +50,28 @@ class ParameterStructure:
50
50
51
51
def __init__ (self ):
52
52
self .parameters = dict ()
53
- self .organization = dict () # Lists parameter grouping and indices
53
+ self .organization = dict ()
54
+ self .indices = dict () # Lists parameter grouping and indices
54
55
self .param_idx = 0
56
+ self .set_params = []
57
+
58
+ def __reduce__ (self ):
59
+ return (ParameterStructure , (self .progress_int ,))
60
+
55
61
56
62
def add_parameter (self , name ):
57
63
self .organization [self .param_idx ] = 1
58
64
self .parameters [self .param_idx ] = name
59
- setattr (self , name + "_index" , self .param_idx )
65
+ self .indices [name + "_index" ] = self .param_idx
66
+ # setattr(self, name+ "_index", self.param_idx)
60
67
self .param_idx += 1
61
68
62
69
def add_multiple_parameters (self , name , amount ):
63
70
self .organization [self .param_idx ] = amount
64
71
for i in range (amount ):
65
72
self .parameters [self .param_idx ] = name + "_" + str (i )
66
- setattr (self , name + "_" + str (i ) + "_index" , self .param_idx )
73
+ self .indices [name + "_" + str (i ) + "_index" ] = self .param_idx
74
+ #setattr(self, name + "_" + str(i) + "_index", self.param_idx)
67
75
self .param_idx += 1
68
76
69
77
def has_parameter (self , name ):
@@ -88,7 +96,18 @@ def __str__(self):
88
96
# When operating, retrieve the weights from param
89
97
def load_params (self , params ):
90
98
for key , name in self .parameters .items (): # This is a parameter name
91
- setattr (self , name , params [getattr (self , name + "_index" )]) # this is an index
99
+ #setattr(self, name, params[getattr(self, name+ "_index")]) # this is an index
100
+ setattr (self , name , params [self .indices [name + "_index" ]])
101
+
102
+ def save (self ):
103
+ return self .__dict__ ['indices' ]
104
+
105
+ def set_params (self ,name , value ):
106
+ if self .set_params .len == 0 :
107
+ self .set_params = np .zeros (self .param_idx ,1 )
108
+ index = getattr (self , 'x_goal' + "_index" )
109
+ self .set_params [index ] = value
110
+
92
111
93
112
class WeightStructure :
94
113
0 commit comments