diff --git a/src/shortest_path/shortest_path_utils.py b/src/shortest_path/shortest_path_utils.py
index 6659d3d..6c21831 100644
--- a/src/shortest_path/shortest_path_utils.py
+++ b/src/shortest_path/shortest_path_utils.py
@@ -1,38 +1,87 @@
 import gurobipy as gp
 from gurobipy import GRB
 from pyepo.model.grb import optGrbModel
+import numpy as np
 
-class shortestPathModel(optGrbModel):
-    '''
-    Custom optModel class. Code written by Bo Tang as part of PyEPO package.
-    '''
-    def __init__(self, grid_size):
-        self.grid_size = grid_size
-        self.grid = (grid_size, grid_size)
-        self.arcs = self._getArcs()
+class shortestPathModel_8(optGrbModel):
+    """
+    Optimization model for the shortest path problem on a 2D grid with 8-neighbor connectivity.
+
+    Attributes:
+        _model (GurobiPy model): Gurobi model
+        grid (tuple of int): size of grid network
+        nodes (list): list of vertices
+        edges (list): list of directed edges (arcs)
+        nodes_map (dict): map from node index to (row, col) coordinate
+    """
+
+    def __init__(self, grid):
+        """
+        Args:
+            grid (tuple of int): size of grid network
+        """
+        self.grid = grid
+        self.nodes, self.edges, self.nodes_map = self._getEdges()
         super().__init__()
 
-    def _getArcs(self):
+    def _getEdges(self):
        """
-        A helper method to get list of arcs for grid network
+        A method to get the lists of nodes and edges for the grid network
 
         Returns:
-            list: arcs
+            tuple: nodes (list), edges (list), nodes_map (dict)
         """
-        arcs = []
+        # init lists
+        nodes, edges = [], []
+        # init map from node index to coordinate
+        nodes_map = {}
         for i in range(self.grid[0]):
-            # edges on rows
-            for j in range(self.grid[1] - 1):
-                v = i * self.grid[1] + j
-                arcs.append((v, v + 1))
-            # edges in columns
-            if i == self.grid[0] - 1:
-                continue
             for j in range(self.grid[1]):
-                v = i * self.grid[1] + j
-                arcs.append((v, v + self.grid[1]))
-        return arcs
+                u = self._calNode(i, j)
+                nodes_map[u] = (i, j)
+                nodes.append(u)
+                # edges to the 8 neighbors
+                # up
+                if i != 0:
+                    v = self._calNode(i-1, j)
+                    edges.append((u, v))
+                # up-right
+                if i != 0 and j != self.grid[1] - 1:
+                    v = self._calNode(i-1, j+1)
+                    edges.append((u, v))
+                # right
+                if j != self.grid[1] - 1:
+                    v = self._calNode(i, j+1)
+                    edges.append((u, v))
+                # down-right
+                if i != self.grid[0] - 1 and j != self.grid[1] - 1:
+                    v = self._calNode(i+1, j+1)
+                    edges.append((u, v))
+                # down
+                if i != self.grid[0] - 1:
+                    v = self._calNode(i+1, j)
+                    edges.append((u, v))
+                # down-left
+                if i != self.grid[0] - 1 and j != 0:
+                    v = self._calNode(i+1, j-1)
+                    edges.append((u, v))
+                # left
+                if j != 0:
+                    v = self._calNode(i, j-1)
+                    edges.append((u, v))
+                # up-left
+                if i != 0 and j != 0:
+                    v = self._calNode(i-1, j-1)
+                    edges.append((u, v))
+        return nodes, edges, nodes_map
+
+    def _calNode(self, x, y):
+        """
+        A method to calculate the index of a node from its (row, col) coordinate
+        """
+        v = x * self.grid[1] + y
+        return v
 
     def _getModel(self):
         """
@@ -41,18 +90,18 @@ def _getModel(self):
         Returns:
             tuple: optimization model and variables
         """
         # create a model
         m = gp.Model("shortest path")
-        # varibles
-        x = m.addVars(self.arcs, name="x")
+        # variables
+        x = m.addVars(self.edges, ub=1, name="x")
         # sense
         m.modelSense = GRB.MINIMIZE
-        # flow conservation constraints
+        # constraints
         for i in range(self.grid[0]):
             for j in range(self.grid[1]):
-                v = i * self.grid[1] + j
+                v = self._calNode(i, j)
                 expr = 0
-                for e in self.arcs:
+                for e in self.edges:
                     # flow in
                     if v == e[1]:
                         expr += x[e]
@@ -68,4 +117,55 @@
                 # transition
                 else:
                     m.addConstr(expr == 0)
-        return m, x
\ No newline at end of file
+        return m, x
+
+    def setObj(self, c):
+        """
+        A method to set the objective function
+
+        Args:
+            c (np.ndarray): vector of node costs
+        """
+        # vector to matrix
+        c = c.reshape(self.grid)
+        # path cost = cost of the source node plus the cost of the head node of every selected edge
+        obj = c[0, 0] + gp.quicksum(c[self.nodes_map[j]] * self.x[i, j] for i, j in self.x)
+        self._model.setObjective(obj)
+
+    def _convert_to_grid(self):
+        '''
+        Converts a path in edge form to grid form
+        '''
+        grid_form = np.zeros(self.grid)
+        grid_form[0, 0] = 1.
+        grid_form[-1, -1] = 1.
+        for i, j in self.edges:
+            grid_form[self.nodes_map[i]] += 1.
+            grid_form[self.nodes_map[j]] += 1.
+        # reshape to vector
+        grid_form = grid_form.reshape(-1)
+        return grid_form
+
+    def solve(self):
+        """
+        A method to solve the model
+
+        Returns:
+            tuple: optimal solution (ndarray over nodes) and objective value (float)
+        """
+        # update gurobi model
+        self._model.update()
+        # solve
+        self._model.optimize()
+        # k x k solution map
+        sol = np.zeros(self.grid)
+        for i, j in self.edges:
+            # active edge
+            if abs(1 - self.x[i, j].x) < 1e-3:
+                # nodes on the active edge
+                sol[self.nodes_map[i]] = 1
+                sol[self.nodes_map[j]] = 1
+        # matrix to vector
+        sol = sol.reshape(-1)
+        return sol, self._model.objVal
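+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not used by the training pipeline).
+    # Assumes a working Gurobi license; the 5x5 grid and the random node costs are
+    # arbitrary choices made just for this smoke test.
+    k = 5
+    model = shortestPathModel_8((k, k))
+    costs = np.random.rand(k * k)
+    model.setObj(costs)
+    sol, obj = model.solve()
+    print("objective value:", obj)
+    print("nodes on the optimal path:\n", sol.reshape(k, k))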
diff --git a/src/warcraft/models.py b/src/warcraft/models.py
index 5eb0ca4..7ced122 100644
--- a/src/warcraft/models.py
+++ b/src/warcraft/models.py
@@ -3,4 +3,45 @@
 import cvxpy as cp
 from cvxpylayers.torch import CvxpyLayer
 from src.dys_opt_net import DYS_opt_net
-from pyepo.model.grb import shortestPathModel
\ No newline at end of file
+from pyepo.model.grb import shortestPathModel
+from torchvision.models import resnet18
+from src.shortest_path.shortest_path_utils import shortestPathModel_8
+
+class WarcraftShortestPathNet(DYS_opt_net):
+    def __init__(self, grid_size, A, b, edges, context_size, device='mps'):
+        super(WarcraftShortestPathNet, self).__init__(A, b, device)
+        self.grid_size = grid_size
+        ## These layers are the stem and first stage of resnet18
+        resnet = resnet18(pretrained=False)
+        self.conv1 = resnet.conv1
+        self.bn = resnet.bn1
+        self.relu = resnet.relu
+        self.maxpool = resnet.maxpool
+        self.block = resnet.layer1
+        # now convert to 1 channel
+        self.conv2 = nn.Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
+        # max pooling
+        self.maxpool2 = nn.AdaptiveMaxPool2d((grid_size, grid_size))
+
+        ## Optimization layer. Can be used within test_time_forward
+        self.shortest_path_solver = shortestPathModel_8((self.grid_size, self.grid_size))
+
+    def data_space_forward(self, d):
+        h = self.conv1(d)
+        h = self.bn(h)
+        h = self.relu(h)
+        h = self.maxpool(h)
+        h = self.block(h)
+        h = self.conv2(h)
+        out = self.maxpool2(h)
+        # reshape for optmodel
+        out = torch.squeeze(out, 1)
+        cost_vec = out.reshape(out.shape[0], -1)
+        return cost_vec
+
+    def F(self, z, cost_vec):
+        return cost_vec + 0.0005*z
+
+    def test_time_forward(self, d):
+        return self.data_space_forward(d)
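+
+
+if __name__ == "__main__":
+    # Minimal shape-check sketch (illustrative only). The 12x12 grid, the 96x96 RGB input,
+    # and the node-arc incidence matrix built below are assumptions made for this example:
+    # they mirror the flow-conservation constraints in shortestPathModel_8, but DYS_opt_net
+    # may expect its constraint data in a different form. A working Gurobi license is assumed.
+    k = 12
+    solver = shortestPathModel_8((k, k))
+    A = torch.zeros(k * k, len(solver.edges))
+    for idx, (u, v) in enumerate(solver.edges):
+        A[u, idx] = -1.0  # flow out of the tail node
+        A[v, idx] = 1.0   # flow into the head node
+    b = torch.zeros(k * k)
+    b[0], b[-1] = -1.0, 1.0  # route one unit of flow from the top-left to the bottom-right node
+    net = WarcraftShortestPathNet(k, A, b, solver.edges, context_size=None, device='cpu')
+    d = torch.randn(2, 3, 96, 96)  # dummy batch of two map images
+    print(net.data_space_forward(d).shape)  # expected: torch.Size([2, 144])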
diff --git a/src/warcraft/trainer.py b/src/warcraft/trainer.py
new file mode 100644
index 0000000..57c00cb
--- /dev/null
+++ b/src/warcraft/trainer.py
@@ -0,0 +1,142 @@
+import torch
+import torch.optim as optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+import time
+import torch.nn as nn
+import pyepo
+import os
+from src.shortest_path.utils import edge_to_node
+from src.utils.accuracy import accuracy
+
+def trainer(net, train_dataset, test_dataset, val_dataset, edges, grid_size, max_time, max_epochs, learning_rate, model_type, weights_dir, device='mps'):
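+    """
+    Train `net` and track validation regret, keeping the checkpoint with the lowest regret.
+
+    Argument descriptions summarize how the values are used below; the datasets are assumed
+    to follow the pyepo optDataset convention, where each batch yields
+    (features, costs, optimal solution, optimal objective value).
+
+    Args:
+        net: model exposing data_space_forward and a `shortest_path_solver` attribute
+        train_dataset, test_dataset, val_dataset: datasets yielding (d, w, opt_sol, opt_value)
+        edges (list): edge list of the grid graph (only used by the commented-out diagnostics)
+        grid_size (int): side length k of the k-by-k grid
+        max_time (float): training time budget in seconds
+        max_epochs (int): maximum number of training epochs
+        learning_rate (float): Adam learning rate
+        model_type (str): one of "DYS", "CVX", "BBOpt", "PertOpt"
+        weights_dir (str): directory for checkpoints (should end with a path separator)
+        device (str): torch device used for the training loop
+
+    Returns:
+        dict: histories of validation regret, accuracy, and epoch times, plus the best test regret
+    """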
+
+    ## Training setup
+    batch_size = 256
+    loader_train = DataLoader(dataset=train_dataset, batch_size=batch_size,
+                              shuffle=True)
+    loader_test = DataLoader(dataset=test_dataset, batch_size=batch_size,
+                             shuffle=False)
+    loader_val = DataLoader(dataset=val_dataset, batch_size=batch_size,
+                            shuffle=False)
+
+    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
+    scheduler = ReduceLROnPlateau(optimizer, 'min')
+
+    # Initialize loss and evaluation metric
+    if model_type == "DYS" or model_type == "CVX":
+        criterion = nn.MSELoss()
+    elif model_type == "BBOpt" or model_type == "PertOpt":
+        criterion = nn.L1Loss()
+
+    metric = pyepo.metric.regret
+
+    if model_type == "BBOpt":
+        dbb = pyepo.func.blackboxOpt(net.shortest_path_solver, lambd=5, processes=1)
+    elif model_type == "PertOpt":
+        ptb = pyepo.func.perturbedOpt(net.shortest_path_solver, n_samples=3, sigma=1.0, processes=2)
+    elif model_type == "DYS" or model_type == "CVX":
+        pass
+    else:
+        raise TypeError("Please choose a supported model!")
+
+    ## Initialize arrays that will be returned and checkpoint directory
+    val_loss_hist = []
+    val_acc_hist = []
+    epoch_time_hist = []
+    checkpt_path = weights_dir + model_type + '/'
+    if not os.path.exists(checkpt_path):
+        os.makedirs(checkpt_path)
+
+    ## Compute initial validation loss and accuracy
+    net.eval()
+    net.to('cpu')
+    best_val_loss = metric(net, net.shortest_path_solver, loader_val)
+
+    print('Initial validation loss is ', best_val_loss)
+    val_loss_hist.append(best_val_loss)
+    time_till_best_val_loss = 0
+
+    curr_val_acc = accuracy(net, net.shortest_path_solver, loader_val)
+    val_acc_hist.append(curr_val_acc)
+
+    ## Compute initial test loss
+    best_test_loss = metric(net, net.shortest_path_solver, loader_test)
+    print('Initial test loss is ', best_test_loss)
+
+    ## Train!
+    epoch = 1
+    train_time = 0
+    train_loss_ave = 0
+
+    while epoch <= max_epochs and train_time <= max_time:
+        start_time_epoch = time.time()
+        net.to(device)
+        # Iterate over the training batches
+        for d_batch, w_batch, opt_sol, opt_value in loader_train:
+            d_batch = d_batch.to(device)
+            w_batch = w_batch.to(device)
+            opt_sol = opt_sol.to(device)
+            opt_value = opt_value.to(device)
+            net.train()
+            optimizer.zero_grad()
+            predicted = net(d_batch)
+            #print(edge_to_node(predicted[1,:], edges, grid_size, device))
+            #print(edge_to_node(opt_sol[1,:], edges, grid_size, device))
+            if model_type == "DYS" or model_type == "CVX":
+                loss = criterion(opt_sol, predicted)
+            elif model_type == "BBOpt":
+                x_predicted = dbb(predicted)
+                loss = criterion(opt_sol, x_predicted)
+            elif model_type == "PertOpt":
+                x_predicted = ptb(predicted)
+                loss = criterion(opt_sol, x_predicted)
+
+            loss.backward()
+            optimizer.step()
+            train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()
+
+        end_time_epoch = time.time()
+        epoch_time = end_time_epoch - start_time_epoch
+        train_time += epoch_time
+        epoch_time_hist.append(epoch_time)
+
+        ## Now compute loss on validation set
+        net.eval()
+        net.to('cpu')
+        val_loss = metric(net, net.shortest_path_solver, loader_val)
+        val_acc = accuracy(net, net.shortest_path_solver, loader_val)
+        val_acc_hist.append(val_acc)
+        print('\n Current validation accuracy is ' + str(val_acc))
+        scheduler.step(val_loss)
+
+        if val_loss < best_val_loss:
+            state_save_name = checkpt_path + 'best.pth'
+            torch.save(net.state_dict(), state_save_name)
+            # If we have achieved the lowest validation regret so far, this will be the
+            # model selected, so we also compute the test loss.
+            best_test_loss = metric(net, net.shortest_path_solver, loader_test)
+            best_val_loss = val_loss
+            print('Best validation regret achieved at epoch ' + str(epoch))
+            time_till_best_val_loss = sum(epoch_time_hist)
+
+        val_loss_hist.append(val_loss)
+
+        print('epoch: ', epoch, 'validation regret is ', val_loss, 'epoch time: ', epoch_time)
+        epoch += 1
+
+    state_save_name = checkpt_path + 'last.pth'
+    torch.save(net.state_dict(), state_save_name)
+    if time_till_best_val_loss < 1e-6:
+        time_till_best_val_loss = sum(epoch_time_hist)
+
+    # Collect results
+    results = {"val_loss_hist": val_loss_hist,
+               "val_acc_hist": val_acc_hist,
+               "epoch_time_hist": epoch_time_hist,
+               "best_test_loss": best_test_loss,
+               "time_till_best_val_loss": time_till_best_val_loss
+               }
+
+    return results
\ No newline at end of file