Skip to content

Commit 4e0a628

Browse files
committed
save parameter indices of solver generation and add example
1 parent 88e8bf7 commit 4e0a628

File tree

7 files changed

+498
-321
lines changed

7 files changed

+498
-321
lines changed

base_mpc/planner/mpcPlanner.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import sys
2+
sys.path.append('/home/luzia/code/forces_pro_client/') # todo remove
3+
import pickle
4+
import forcespro.nlp
5+
6+
class MPCPlanner(object):
7+
8+
def __init__(self, robot_type, solver_dir):
9+
10+
11+
self.solver = forcespro.nlp.Solver.from_directory(solver_dir + "/AlbertFORCESNLPsolver") #todo add robot_type
12+
13+
with open(solver_dir + "/Albertparams.pkl", 'rb') as file:
14+
self.parameters = pickle.load(file)
15+
16+
17+
18+
def updateDynamicObstacles(self, obstArray):
19+
print('not implemented yet')
20+
def setx0(self, xinit):
21+
print('not implemented yet')
22+
def solve(self, current_state):
23+
#self._xinit = current_state[0: self._nx]
24+
self.setX0(self._xinit)
25+
print('not implemented yet')

base_mpc/python_solver_generation/base_mpc_solver.py

+9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import copy
20+
import pickle
2021
import forcespro.nlp
2122

2223
import dynamics
@@ -158,8 +159,16 @@ def objective_with_stage_index(stage_idx):
158159
#output1 = ("sol", [], [])
159160
generated_solver = solver.generate_solver(options) #, outputs
160161

162+
#save settings
163+
params_dict = settings.params.save()
164+
file_name = dir_path + "/Albertparams.pkl"
165+
166+
with open(file_name, 'wb') as outp:
167+
pickle.dump(params_dict, outp, pickle.HIGHEST_PROTOCOL)
168+
161169
# Move the solver up a directory for convenience
162170
if os.path.isdir(solver_path):
163171
shutil.move(solver_path, new_solver_path)
164172

165173

174+

base_mpc/python_solver_generation/helpers.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,28 @@ class ParameterStructure:
5050

5151
def __init__(self):
5252
self.parameters = dict()
53-
self.organization = dict() # Lists parameter grouping and indices
53+
self.organization = dict()
54+
self.indices = dict() # Lists parameter grouping and indices
5455
self.param_idx = 0
56+
self.set_params = []
57+
58+
def __reduce__(self):
59+
return (ParameterStructure, (self.progress_int,))
60+
5561

5662
def add_parameter(self, name):
5763
self.organization[self.param_idx] = 1
5864
self.parameters[self.param_idx] = name
59-
setattr(self, name+ "_index", self.param_idx)
65+
self.indices[name+ "_index"] = self.param_idx
66+
# setattr(self, name+ "_index", self.param_idx)
6067
self.param_idx += 1
6168

6269
def add_multiple_parameters(self, name, amount):
6370
self.organization[self.param_idx] = amount
6471
for i in range(amount):
6572
self.parameters[self.param_idx] = name + "_" + str(i)
66-
setattr(self, name + "_" + str(i) + "_index", self.param_idx)
73+
self.indices[name + "_" + str(i) + "_index"] = self.param_idx
74+
#setattr(self, name + "_" + str(i) + "_index", self.param_idx)
6775
self.param_idx += 1
6876

6977
def has_parameter(self, name):
@@ -88,7 +96,18 @@ def __str__(self):
8896
# When operating, retrieve the weights from param
8997
def load_params(self, params):
9098
for key, name in self.parameters.items(): # This is a parameter name
91-
setattr(self, name, params[getattr(self, name+ "_index")]) # this is an index
99+
#setattr(self, name, params[getattr(self, name+ "_index")]) # this is an index
100+
setattr(self, name, params[self.indices[name + "_index"]])
101+
102+
def save(self):
103+
return self.__dict__['indices']
104+
105+
def set_params(self,name, value):
106+
if self.set_params.len == 0:
107+
self.set_params = np.zeros(self.param_idx,1)
108+
index = getattr(self, 'x_goal'+ "_index")
109+
self.set_params[index] = value
110+
92111

93112
class WeightStructure:
94113

0 commit comments

Comments
 (0)