Skip to content

Commit

Permalink
Update PSL MTL
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhang2523 committed Sep 3, 2024
1 parent 3289940 commit 4ac3c91
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 37 deletions.
14 changes: 8 additions & 6 deletions libmoon/solver/mobo/methods/base_solver_mobod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@


class MOBOD(object):
def __init__(self, problem, n_init, MAX_FE, BATCH_SIZE):
def __init__(self, problem, x_init, MAX_FE, BATCH_SIZE):
self.n_var = problem.n_var
self.n_obj = problem.n_obj
self.n_init = n_init
self.x_init = x_init
self.n_init = x_init.shape[0]
self.MAX_FE = MAX_FE
self.BATCH_SIZE = BATCH_SIZE
self.max_iter = math.ceil((MAX_FE - n_init)/BATCH_SIZE)
self.max_iter = math.ceil((MAX_FE - self.n_init)/BATCH_SIZE)
self.problem = problem
self.bounds = torch.from_numpy(np.vstack((problem.lbound,problem.ubound)))

Expand All @@ -45,8 +46,8 @@ def _step(self,batch_size):

def solve(self):
# get initial samples
x_init = torch.from_numpy(lhs(self.n_var,samples=self.n_init))
x_init = self.bounds[0,...] + (self.bounds[1,...] - self.bounds[0,...]) * x_init
# x_init = torch.from_numpy(lhs(self.n_var,samples=self.n_init))
x_init = self.bounds[0,...] + (self.bounds[1,...] - self.bounds[0,...]) * self.x_init
y_init = self.problem.evaluate(x_init)
self._record(x_init, y_init)
# generate reference vectors
Expand All @@ -59,6 +60,7 @@ def solve(self):
pass

hv_dict = {}
hv_dict[self.n_init] = compute_hv(self.y.detach().cpu().numpy(), self.problem.problem_name)
for i in tqdm(range(self.max_iter)):
# Scale the objective values
train_x = self.x.clone()
Expand All @@ -82,7 +84,7 @@ def solve(self):
self._record(new_x, new_obj)
hv_val = compute_hv(self.y.detach().cpu().numpy(), self.problem.problem_name)
print('Iteration: %d, HV: %.4f'%(i,hv_val))
hv_dict[i*batch_size + self.n_init] = hv_val
hv_dict[(i+1)*batch_size + self.n_init] = hv_val

res = {}
res['x'] = self.x.detach().numpy()
Expand Down
17 changes: 10 additions & 7 deletions libmoon/solver/mobo/methods/base_solver_pslmobo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
Computation, 28(2): 432-444, 2024.
'''
class PSLMOBO(object):
def __init__(self, problem, n_init, MAX_FE, BATCH_SIZE):
def __init__(self, problem, x_init, MAX_FE, BATCH_SIZE):
self.n_var = problem.n_var
self.n_obj = problem.n_obj
self.n_init = n_init
self.x_init = x_init
self.n_init = x_init.shape[0]
self.MAX_FE = MAX_FE
self.BATCH_SIZE = BATCH_SIZE
self.max_iter = math.ceil((MAX_FE - n_init)/BATCH_SIZE)
self.max_iter = math.ceil((MAX_FE - self.n_init)/BATCH_SIZE)
self.problem = problem
self.bounds = torch.from_numpy(np.vstack((problem.lbound,problem.ubound)))
self.x = None
Expand All @@ -41,12 +42,14 @@ def _batch_selection(self, batch_size):

def solve(self):
# get initial samples
x_init = torch.from_numpy(lhs(self.n_var,samples=self.n_init))
x_init = self.bounds[0,...] + (self.bounds[1,...] - self.bounds[0,...]) * x_init
x_init = self.bounds[0,...] + (self.bounds[1,...] - self.bounds[0,...]) * self.x_init
y_init = self.problem.evaluate(x_init)
self._record(x_init, y_init)

hv_dict = {}
hv_dict[self.n_init] = compute_hv(self.y.detach().cpu().numpy(), self.problem.problem_name)
print('Iteration: %d, HV: %.4f' % (0, compute_hv(self.y.detach().cpu().numpy(), self.problem.problem_name)))

for i in tqdm(range(self.max_iter)):
# solution normalization x: [0,1]^d, y: [0,1]^m
train_x = normalize(self.x, self.bounds)
Expand Down Expand Up @@ -74,8 +77,8 @@ def solve(self):
# self.y: HV
# print()
hv_val = compute_hv(self.y.detach().cpu().numpy(), self.problem.problem_name)
print('Iteration: %d, HV: %.4f' % (i, hv_val))
hv_dict[i * batch_size + self.n_init] = hv_val
print('Iteration: %d, HV: %.4f' % ((i+1), hv_val))
hv_dict[(i+1) * batch_size + self.n_init] = hv_val

res = {}
res['x'] = self.x.detach().numpy()
Expand Down
4 changes: 2 additions & 2 deletions libmoon/solver/mobo/methods/dirhvego_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
)

class DirHVEGOSolver(MOBOD):
def __init__(self, problem, n_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, n_init, MAX_FE, BATCH_SIZE)
def __init__(self, problem, x_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, x_init, MAX_FE, BATCH_SIZE)
self.solver_name = 'dirhvego'

def _get_acquisition(self, u, sigma, ref_vec, pref_inc):
Expand Down
4 changes: 2 additions & 2 deletions libmoon/solver/mobo/methods/psldirhvei_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
phi, # Standard normal PDF
)
class PSLDirHVEISolver(PSLMOBO):
def __init__(self, problem, n_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, n_init, MAX_FE, BATCH_SIZE)
def __init__(self, problem, x_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, x_init, MAX_FE, BATCH_SIZE)
self.solver_name = 'psldirhvei'


Expand Down
4 changes: 2 additions & 2 deletions libmoon/solver/mobo/methods/pslmobo_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@


class PSLMOBOSolver(PSLMOBO):
def __init__(self, problem, n_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, n_init, MAX_FE, BATCH_SIZE)
def __init__(self, problem, x_init, MAX_FE, BATCH_SIZE):
super().__init__(problem, x_init, MAX_FE, BATCH_SIZE)
self.solver_name = 'pslmobo'
self.coef_lcb = 0.1 # coefficient of LCB

Expand Down
175 changes: 175 additions & 0 deletions libmoon/solver/mobo/tester/T
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
tensor([[5.4564e-01, 2.5178e-01, 3.4049e-01, 1.5126e-01, 4.2823e-02, 5.5907e-01,
9.4332e-01, 1.8682e-01],
[5.9989e-01, 5.0448e-01, 6.9477e-01, 7.8029e-02, 7.9084e-01, 2.1021e-01,
6.1106e-01, 2.2316e-01],
[9.8930e-01, 5.5632e-01, 7.1048e-02, 4.3747e-01, 9.7104e-01, 1.4319e-01,
3.5873e-01, 2.9765e-01],
[4.6178e-01, 3.7950e-01, 9.9695e-01, 8.5571e-01, 3.8517e-01, 9.6339e-01,
5.4158e-01, 1.4455e-01],
[5.1072e-01, 9.3353e-01, 5.0304e-01, 8.4203e-01, 4.1580e-01, 1.2021e-01,
1.5015e-01, 3.2584e-01],
[8.0147e-01, 5.9792e-01, 4.6759e-01, 9.8229e-01, 2.1590e-01, 9.7152e-01,
6.5522e-01, 3.8043e-01],
[2.1424e-01, 2.2604e-01, 9.7829e-01, 5.1677e-01, 5.7157e-01, 7.5528e-01,
4.0852e-01, 8.7929e-01],
[1.5886e-01, 6.2594e-01, 6.1224e-01, 1.8328e-01, 7.0364e-01, 7.0052e-01,
1.7617e-01, 8.0931e-01],
[4.5601e-01, 1.0931e-01, 5.9617e-01, 7.4635e-01, 5.7597e-01, 5.6939e-01,
8.4941e-01, 3.5765e-01],
[1.3683e-01, 8.5432e-01, 7.8409e-01, 5.8654e-01, 8.9857e-01, 9.9433e-01,
4.7238e-01, 6.7298e-01],
[3.9164e-02, 8.2202e-01, 8.2920e-01, 9.9053e-01, 3.2475e-01, 1.5400e-01,
3.7472e-01, 8.2426e-01],
[3.5735e-01, 9.2839e-02, 3.0201e-01, 8.8985e-01, 7.1424e-01, 8.4473e-01,
2.2651e-01, 6.5723e-01],
[1.4708e-01, 6.6120e-01, 2.4479e-01, 1.1128e-01, 5.5993e-01, 5.2384e-02,
5.6449e-01, 9.5064e-01],
[9.6738e-01, 9.5760e-01, 9.1797e-01, 2.3041e-01, 9.2592e-01, 8.2725e-02,
5.8280e-01, 8.3149e-01],
[3.9589e-01, 6.6760e-01, 8.9793e-01, 4.7514e-01, 1.5665e-01, 3.7526e-01,
5.5631e-01, 5.6787e-01],
[4.2713e-01, 8.0043e-01, 8.2631e-01, 1.5215e-04, 7.5643e-01, 8.8009e-01,
6.8101e-01, 8.5899e-01],
[8.9502e-01, 5.8983e-02, 5.3446e-01, 9.5788e-01, 1.0664e-01, 2.8699e-01,
8.1303e-01, 2.4121e-01],
[7.7222e-01, 7.1832e-01, 2.1827e-01, 5.5569e-01, 4.6675e-01, 7.3842e-01,
9.7533e-01, 7.5117e-01],
[1.0115e-01, 8.5015e-01, 7.2159e-01, 3.8866e-01, 6.7777e-01, 2.9430e-01,
8.8070e-01, 1.6305e-01],
[8.4598e-01, 4.0383e-01, 9.6228e-01, 2.6301e-01, 6.9619e-02, 3.9552e-01,
2.6818e-02, 9.1998e-01],
[9.2710e-01, 7.3698e-01, 3.4518e-01, 3.4033e-02, 5.3151e-01, 7.2537e-01,
1.4887e-01, 5.2941e-01],
[3.9038e-01, 7.3332e-02, 2.8911e-01, 8.2457e-01, 1.8469e-01, 1.6502e-01,
2.5718e-01, 7.8294e-01],
[7.4249e-01, 3.3246e-01, 4.3042e-01, 6.5701e-01, 4.0808e-01, 2.3389e-01,
8.8749e-01, 8.0973e-02],
[6.2958e-01, 9.7580e-03, 9.2511e-01, 1.8884e-01, 3.9342e-01, 8.3820e-01,
6.4938e-01, 1.5430e-01],
[2.8315e-01, 9.6886e-01, 4.8732e-01, 6.2298e-01, 2.2477e-01, 3.1463e-01,
7.1176e-01, 3.7458e-01],
[4.4479e-01, 2.6634e-01, 5.1463e-01, 3.1616e-01, 6.4787e-01, 6.7250e-01,
9.3650e-01, 3.5543e-01],
[5.3864e-01, 8.0659e-01, 3.4171e-02, 3.6026e-01, 8.0362e-01, 2.0679e-01,
7.3067e-01, 5.4985e-03],
[3.0997e-01, 3.7581e-01, 6.5652e-01, 5.2184e-01, 4.8242e-01, 1.3375e-01,
4.4529e-01, 5.8201e-01],
[9.7954e-01, 6.8415e-01, 8.4078e-01, 4.8797e-02, 3.3569e-01, 6.2714e-02,
1.6111e-01, 7.7085e-01],
[3.3680e-03, 7.6747e-01, 6.4944e-01, 3.9185e-01, 9.3486e-01, 7.8344e-01,
1.1285e-02, 4.0100e-01],
[7.4759e-01, 4.9034e-01, 1.7356e-01, 7.6577e-01, 2.3256e-01, 4.5369e-01,
9.1675e-01, 6.5151e-01],
[2.5252e-01, 7.5448e-01, 1.5880e-01, 9.1456e-01, 8.4912e-01, 6.4668e-01,
3.0520e-01, 2.7520e-01],
[8.7043e-01, 1.3437e-01, 1.6130e-01, 8.1399e-01, 9.1227e-01, 8.6603e-01,
3.1638e-01, 4.7521e-01],
[8.7519e-01, 2.9072e-01, 3.8732e-01, 3.0566e-01, 3.1391e-01, 8.5468e-01,
4.2474e-01, 4.1291e-01],
[4.7718e-01, 4.7549e-01, 1.4320e-01, 8.3171e-01, 1.4779e-01, 1.8680e-01,
8.6109e-01, 3.0197e-01],
[2.6710e-01, 3.4116e-01, 2.8667e-01, 7.2235e-01, 2.0660e-01, 7.7012e-01,
6.4012e-01, 4.9703e-01],
[3.7445e-01, 9.4426e-01, 7.7868e-01, 6.8825e-01, 3.9912e-03, 7.6867e-01,
2.5231e-01, 6.9399e-01],
[9.5857e-01, 9.2549e-01, 6.0366e-01, 8.7915e-01, 8.1004e-01, 9.4514e-01,
2.0967e-01, 5.9868e-01],
[9.0662e-01, 5.2985e-01, 8.7214e-02, 1.3068e-02, 8.7872e-01, 3.3726e-01,
7.7500e-01, 8.0132e-01],
[8.3442e-02, 1.5987e-01, 8.0244e-01, 7.9861e-01, 4.5244e-01, 7.6977e-02,
4.9497e-01, 9.8540e-01],
[1.7211e-02, 4.3963e-01, 3.9878e-01, 3.5299e-01, 8.7121e-01, 2.2169e-01,
3.2304e-01, 4.5953e-01],
[6.5836e-01, 5.1175e-01, 8.6871e-01, 4.5975e-01, 1.2508e-01, 1.7028e-03,
9.0690e-01, 8.8574e-01],
[3.5108e-01, 5.7840e-01, 3.5660e-01, 6.0996e-01, 8.5555e-01, 9.8065e-01,
9.2788e-01, 5.2479e-01],
[5.0547e-01, 1.2533e-01, 6.8123e-01, 7.9148e-01, 1.8241e-01, 3.4911e-01,
5.3862e-01, 9.4046e-01],
[5.7937e-01, 1.7068e-01, 9.4632e-01, 9.6949e-01, 6.3670e-01, 1.3676e-02,
5.9223e-01, 5.5669e-01],
[7.8400e-01, 4.6354e-01, 1.2119e-01, 2.0251e-01, 7.7572e-01, 4.1827e-01,
7.7916e-02, 3.6172e-02],
[2.3403e-01, 8.6540e-01, 6.3252e-01, 5.5149e-01, 9.4414e-01, 1.7763e-01,
9.8256e-01, 2.8106e-01],
[9.1932e-01, 7.2594e-01, 3.3175e-01, 4.1232e-01, 2.0944e-02, 5.5008e-01,
5.5868e-02, 7.0453e-01],
[4.2333e-01, 6.1034e-01, 9.7918e-02, 5.3541e-01, 8.9438e-01, 5.1894e-01,
2.7014e-01, 8.4121e-01],
[7.6894e-01, 9.8417e-01, 3.6006e-02, 6.9352e-01, 9.9930e-01, 6.4228e-01,
1.0796e-01, 8.9721e-01],
[1.1709e-01, 7.1025e-01, 8.8967e-01, 7.7617e-01, 1.6500e-01, 1.0298e-01,
1.7371e-02, 3.1897e-01],
[3.2567e-01, 6.4787e-01, 8.5421e-01, 3.2240e-01, 3.4492e-01, 6.2676e-01,
9.5652e-01, 6.4353e-01],
[7.2224e-02, 9.0711e-01, 6.7680e-01, 6.0739e-01, 3.6785e-01, 6.0026e-01,
7.5810e-01, 4.9210e-01],
[9.3992e-01, 2.8583e-01, 7.2763e-01, 8.3596e-02, 4.9811e-01, 4.6526e-01,
1.0000e+00, 4.1724e-01],
[6.9953e-01, 1.8845e-02, 1.8409e-01, 1.2019e-01, 6.0995e-01, 2.6710e-01,
4.6303e-01, 9.9077e-01],
[6.1174e-01, 1.8429e-01, 4.7651e-01, 3.5097e-02, 3.0916e-01, 6.7818e-01,
7.8857e-01, 3.4344e-01],
[8.0940e-01, 5.2800e-01, 5.3702e-02, 2.2327e-01, 5.0710e-01, 8.8906e-01,
1.8645e-01, 6.5467e-02],
[1.6389e-01, 3.4808e-01, 8.7864e-01, 9.0040e-01, 7.6194e-01, 8.1474e-01,
3.9797e-01, 2.3921e-02],
[2.8817e-01, 9.0915e-01, 5.8108e-02, 7.0396e-01, 9.8072e-01, 5.9085e-01,
2.3925e-01, 2.1451e-01],
[4.6891e-02, 3.9678e-01, 5.4548e-01, 5.8319e-01, 5.4510e-01, 4.7280e-01,
7.5991e-01, 4.2617e-01],
[9.4596e-01, 4.5436e-01, 2.2434e-01, 9.5481e-02, 5.2161e-01, 5.0698e-01,
7.3851e-01, 7.4119e-01],
[6.8886e-01, 3.1496e-01, 2.4052e-01, 1.4472e-01, 7.4014e-01, 9.1198e-01,
8.5745e-02, 7.2205e-01],
[1.9097e-01, 5.4156e-01, 4.2437e-01, 2.9216e-01, 6.0484e-02, 3.8315e-01,
7.9647e-01, 9.7574e-01],
[6.4082e-01, 1.4896e-01, 6.2245e-01, 6.7494e-01, 6.5571e-01, 2.4260e-01,
5.1253e-01, 9.2360e-02],
[4.8955e-01, 7.9233e-01, 7.0892e-01, 4.1421e-01, 2.8887e-01, 6.6633e-01,
7.0106e-01, 5.1431e-01],
[8.2962e-01, 3.6099e-01, 7.4685e-01, 7.5146e-01, 2.7401e-01, 2.5723e-01,
3.8840e-01, 7.7901e-02],
[8.5670e-01, 8.9227e-01, 1.3448e-01, 3.7872e-01, 5.9236e-01, 3.5953e-01,
4.3035e-01, 1.1953e-01],
[1.0840e-01, 5.8739e-01, 1.0961e-01, 1.7194e-01, 6.2567e-01, 4.4722e-01,
4.3340e-02, 5.4334e-02],
[2.6080e-01, 9.9861e-01, 1.5731e-02, 9.3153e-01, 6.8244e-01, 1.0493e-01,
5.2511e-01, 8.7053e-01],
[7.1785e-01, 3.7152e-02, 3.1320e-01, 9.2990e-01, 4.8680e-01, 7.1755e-01,
8.3084e-01, 5.9040e-01],
[5.2699e-01, 2.3941e-01, 2.6953e-01, 1.3020e-01, 9.6239e-01, 4.9595e-01,
6.2471e-01, 6.2967e-01],
[1.9691e-01, 4.2492e-01, 4.1357e-01, 4.6050e-01, 8.4722e-02, 8.9788e-01,
2.9176e-01, 4.6251e-01],
[7.1225e-01, 2.1253e-01, 8.0583e-01, 9.5188e-01, 4.2713e-01, 7.9367e-01,
8.7093e-01, 1.1444e-01],
[5.9723e-01, 6.9493e-01, 5.7506e-01, 4.3572e-01, 3.6347e-01, 5.3068e-01,
4.9009e-01, 1.2828e-01],
[6.0578e-02, 8.8352e-01, 3.7554e-01, 3.4150e-01, 5.0811e-02, 4.9105e-01,
6.7012e-01, 7.6106e-01],
[6.7371e-01, 4.7007e-02, 7.6273e-01, 7.2500e-01, 2.4425e-01, 3.0353e-01,
8.2666e-01, 5.4335e-01],
[7.2532e-01, 3.0514e-01, 9.4142e-01, 2.0765e-01, 7.2902e-01, 6.1904e-01,
3.5093e-01, 2.6148e-01],
[6.5241e-01, 5.6402e-01, 2.6127e-01, 2.5278e-01, 6.0199e-01, 4.1290e-01,
1.2920e-01, 6.1563e-01],
[5.7028e-01, 6.4284e-01, 5.7267e-01, 2.8449e-01, 4.3748e-01, 3.2697e-01,
3.4111e-01, 9.6522e-01],
[5.5637e-01, 8.4756e-02, 5.2855e-01, 5.0228e-01, 8.3836e-01, 7.0306e-01,
7.2233e-01, 1.7509e-01],
[4.0941e-01, 2.5354e-01, 5.5944e-01, 6.4800e-01, 9.8592e-02, 5.8168e-01,
1.2327e-01, 1.4073e-02],
[3.4184e-01, 7.7475e-01, 7.5182e-01, 8.7103e-01, 3.3409e-02, 9.2904e-01,
6.5325e-02, 9.1166e-01],
[8.2039e-01, 1.8310e-01, 4.3724e-01, 5.7073e-01, 6.9978e-01, 3.4614e-02,
9.9374e-02, 4.4710e-01],
[1.7703e-01, 1.9554e-01, 4.5178e-01, 4.9424e-01, 2.5601e-01, 9.3218e-01,
5.9863e-01, 1.9728e-01],
[2.2122e-01, 3.2899e-02, 7.1020e-03, 2.6787e-01, 1.3514e-01, 8.2407e-01,
2.0010e-01, 7.2697e-01],
[3.1558e-01, 4.3547e-01, 1.9673e-01, 6.0938e-02, 2.8390e-01, 4.3289e-01,
2.8729e-01, 2.4406e-01],
[2.4113e-02, 8.3314e-01, 9.6758e-01, 6.4248e-01, 8.1615e-01, 3.1886e-02,
4.5262e-01, 6.8182e-01]])

Loading

0 comments on commit 4ac3c91

Please sign in to comment.