Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mines-opt-ml/fpo-dys into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielMckenzie committed Jan 18, 2024
2 parents f68545f + 800484c commit 3d13260
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 28 deletions.
154 changes: 127 additions & 27 deletions src/shortest_path/shortest_path_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,87 @@
import gurobipy as gp
from gurobipy import GRB
from pyepo.model.grb import optGrbModel
import numpy as np

class shortestPathModel(optGrbModel):
'''
Custom optModel class. Code written by Bo Tang as part of PyEPO package.
'''

def __init__(self, grid_size):
self.grid_size = grid_size
self.grid = (grid_size, grid_size)
self.arcs = self._getArcs()
class shortestPathModel_8(optGrbModel):
"""
This class is optimization model for shortest path problem on 2D grid with 8 neighbors
Attributes:
_model (GurobiPy model): Gurobi model
grid (tuple of int): Size of grid network
nodes (list): list of vertex
edges (list): List of arcs
nodes_map (ndarray): 2D array for node index
"""

def __init__(self, grid):
    """
    Store the grid size and derive the graph before the base class is set up.

    Args:
        grid (tuple of int): size of grid network
    """
    self.grid = grid
    # _getEdges reads self.grid, so the grid has to be stored first
    nodes, edges, nodes_map = self._getEdges()
    self.nodes = nodes
    self.edges = edges
    self.nodes_map = nodes_map
    # NOTE(review): the base-class constructor presumably builds the Gurobi
    # model from self.edges -- keep this call last; confirm against optGrbModel
    super().__init__()

def _getArcs(self):
def _getEdges(self):
"""
A helper method to get list of arcs for grid network
A method to get list of edges for grid network
Returns:
list: arcs
"""
arcs = []
# init list
nodes, edges = [], []
# init map from coord to ind
nodes_map = {}
for i in range(self.grid[0]):
# edges on rows
for j in range(self.grid[1] - 1):
v = i * self.grid[1] + j
arcs.append((v, v + 1))
# edges in columns
if i == self.grid[0] - 1:
continue
for j in range(self.grid[1]):
v = i * self.grid[1] + j
arcs.append((v, v + self.grid[1]))
return arcs
u = self._calNode(i, j)
nodes_map[u] = (i,j)
nodes.append(u)
# edge to 8 neighbors
# up
if i != 0:
v = self._calNode(i-1, j)
edges.append((u,v))
# up-right
if j != self.grid[1] - 1:
v = self._calNode(i-1, j+1)
edges.append((u,v))
# right
if j != self.grid[1] - 1:
v = self._calNode(i, j+1)
edges.append((u,v))
# down-right
if i != self.grid[0] - 1:
v = self._calNode(i+1, j+1)
edges.append((u,v))
# down
if i != self.grid[0] - 1:
v = self._calNode(i+1, j)
edges.append((u,v))
# down-left
if j != 0:
v = self._calNode(i+1, j-1)
edges.append((u,v))
# left
if j != 0:
v = self._calNode(i, j-1)
edges.append((u,v))
# top-left
if i != 0:
v = self._calNode(i-1, j-1)
edges.append((u,v))
return nodes, edges, nodes_map

def _calNode(self, x, y):
    """
    Flatten a grid coordinate into a single node index.

    Args:
        x: first grid coordinate (scaled by ``self.grid[1]``)
        y: second grid coordinate (added as the offset)
    Returns:
        the node index ``x * self.grid[1] + y``
    """
    stride = self.grid[1]
    return x * stride + y

def _getModel(self):
"""
Expand All @@ -41,18 +90,18 @@ def _getModel(self):
Returns:
tuple: optimization model and variables
"""
# create a model
# create a model
m = gp.Model("shortest path")
# variables
x = m.addVars(self.arcs, name="x")
x = m.addVars(self.edges, ub=1, name="x")
# sense
m.modelSense = GRB.MINIMIZE
# flow conservation constraints
# constraints
for i in range(self.grid[0]):
for j in range(self.grid[1]):
v = i * self.grid[1] + j
v = self._calNode(i, j)
expr = 0
for e in self.arcs:
for e in self.edges:
# flow in
if v == e[1]:
expr += x[e]
Expand All @@ -68,4 +117,55 @@ def _getModel(self):
# transition
else:
m.addConstr(expr == 0)
return m, x
return m, x

def setObj(self, c):
    """
    Install the objective function for a given cost vector.

    Args:
        c (np.ndarray): cost of objective function; reshaped to ``self.grid``
    """
    # view the cost vector as a grid so nodes_map coordinates can index it
    cost_grid = c.reshape(self.grid)
    # each variable x[i, j] is weighted by the cost of the cell of node j
    edge_terms = gp.quicksum(
        cost_grid[self.nodes_map[head]] * self.x[tail, head]
        for tail, head in self.x
    )
    # the cost of cell (0, 0) enters as a constant term
    self._model.setObjective(cost_grid[0, 0] + edge_terms)

def _convert_to_grid(self):
    """
    Build a flattened per-node count grid from ``self.edges``.

    Cells (0, 0) and (-1, -1) start at 1; every edge in ``self.edges`` then
    adds 1 to the cell of each of its two endpoints (looked up through
    ``self.nodes_map``).

    NOTE(review): this walks *all* of ``self.edges``, not a selected path --
    confirm that is the intended input for the "path in edge form".

    Returns:
        np.ndarray: 1-D array with ``self.grid[0] * self.grid[1]`` entries
    """
    grid_form = np.zeros(self.grid)
    grid_form[0, 0] = 1.
    grid_form[-1, -1] = 1.
    for i, j in self.edges:
        grid_form[self.nodes_map[i]] += 1.
        grid_form[self.nodes_map[j]] += 1.
    # flatten to a vector; reshape is a method call (the original subscripted
    # it with [-1], which raises TypeError)
    return grid_form.reshape(-1)


def solve(self):
    """
    Optimize the model and report which grid cells lie on active edges.

    Returns:
        tuple: flattened 0/1 node map (np.ndarray) and objective value
    """
    model = self._model
    # push pending changes into the gurobi model, then optimize
    model.update()
    model.optimize()
    # grid-shaped map of cells touched by an active edge
    visited = np.zeros(self.grid)
    for tail, head in self.edges:
        # an edge counts as active when its value is within 1e-3 of 1
        is_active = abs(1 - self.x[tail, head].x) < 1e-3
        if not is_active:
            continue
        visited[self.nodes_map[tail]] = 1
        visited[self.nodes_map[head]] = 1
    # matrix to vector
    flat = visited.reshape(-1)
    return flat, model.objVal
43 changes: 42 additions & 1 deletion src/warcraft/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,45 @@
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from src.dys_opt_net import DYS_opt_net
from pyepo.model.grb import shortestPathModel
from pyepo.model.grb import shortestPathModel
from torchvision.models import resnet18
from src.shortest_path.shortest_path_utils import shortestPathModel_8

class WarcraftShortestPathNet(DYS_opt_net):
    """
    Network that maps an input image batch to a per-cell cost vector.

    The feature extractor reuses the first layers of a (non-pretrained)
    torchvision resnet18, collapses the 64 feature channels to one with a
    1x1 convolution and pools the result to a ``grid_size x grid_size`` map.

    Args:
        grid_size (int): side length of the square cost grid
        A, b: forwarded unchanged to ``DYS_opt_net.__init__``
        edges: not used by this class
        context_size: not used by this class
        device (str): forwarded unchanged to ``DYS_opt_net.__init__``
    """

    def __init__(self, grid_size, A, b, edges, context_size, device='mps'):
        super(WarcraftShortestPathNet, self).__init__(A, b, device)
        self.grid_size = grid_size
        ## These layers are like resnet18
        resnet = resnet18(pretrained=False)
        self.conv1 = resnet.conv1
        self.bn = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.block = resnet.layer1
        # now convert to 1 channel
        self.conv2 = nn.Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
        # max pooling down to exactly grid_size x grid_size
        self.maxpool2 = nn.AdaptiveMaxPool2d((grid_size, grid_size))

        ## Optimization layer. Can be used within test_time_forward
        self.shortest_path_solver = shortestPathModel_8((self.grid_size, self.grid_size))

    def data_space_forward(self, d):
        """
        Run the convolutional feature extractor and return flattened costs.

        Args:
            d (torch.Tensor): input batch fed to ``conv1``
                (presumably (batch, 3, H, W) images -- TODO confirm)
        Returns:
            torch.Tensor: one row per sample, pooled cost map flattened
        """
        h = self.conv1(d)
        h = self.bn(h)
        h = self.relu(h)
        # the layer is registered as ``maxpool`` in __init__; the previous
        # ``self.maxpool1`` lookup raised AttributeError on every forward pass
        h = self.maxpool(h)
        h = self.block(h)
        h = self.conv2(h)
        out = self.maxpool2(h)
        # reshape for optmodel: drop the single channel, then flatten the grid
        out = torch.squeeze(out, 1)
        cost_vec = out.reshape(out.shape[0], -1)
        return cost_vec

    def F(self, z, cost_vec):
        """Return ``cost_vec + 0.0005 * z``."""
        return cost_vec + 0.0005*z

    def test_time_forward(self, d):
        """Return the predicted cost vector; identical to data_space_forward."""
        return self.data_space_forward(d)

142 changes: 142 additions & 0 deletions src/warcraft/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import time as time
import torch.nn as nn
import pyepo
import os
from src.shortest_path.utils import edge_to_node
from src.utils.accuracy import accuracy

def trainer(net, train_dataset, test_dataset, val_dataset, edges, grid_size, max_time, max_epochs, learning_rate, model_type, weights_dir, device='mps'):
    """
    Train ``net`` and track validation regret, accuracy and epoch times.

    Training runs until either ``max_epochs`` epochs are done or the
    accumulated training time exceeds ``max_time``; the time budget is only
    checked between epochs. Evaluation always runs on the CPU.

    Args:
        net: model to train; must expose ``shortest_path_solver``
        train_dataset, test_dataset, val_dataset: datasets whose batches
            unpack as ``(d_batch, w_batch, opt_sol, opt_value)``
        edges: not used by this function
        grid_size: not used by this function
        max_time (float): training-time budget in seconds
        max_epochs (int): maximum number of epochs
        learning_rate (float): Adam learning rate
        model_type (str): one of "DYS", "CVX", "BBOpt", "PertOpt"
        weights_dir (str): checkpoint root; ``model_type + '/'`` is appended
            by plain string concatenation, so include a trailing separator
        device (str): device used for the training passes

    Returns:
        dict: ``val_loss_hist``, ``val_acc_hist``, ``epoch_time_hist``,
        ``best_test_loss`` and ``time_till_best_val_loss``

    Raises:
        TypeError: if ``model_type`` is not supported
    """
    # Reject unsupported model types before any setup work is done
    if model_type not in ("DYS", "CVX", "BBOpt", "PertOpt"):
        raise TypeError("Please choose a supported model!")

    ## Training setup
    batch_size = 256
    loader_train = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True)
    loader_test = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False)
    loader_val = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False)

    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min')

    # Initialize loss and evaluation metric
    if model_type in ("DYS", "CVX"):
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    metric = pyepo.metric.regret

    # Differentiable wrappers around the combinatorial solver
    if model_type == "BBOpt":
        dbb = pyepo.func.blackboxOpt(net.shortest_path_solver, lambd=5, processes=1)
    elif model_type == "PertOpt":
        ptb = pyepo.func.perturbedOpt(net.shortest_path_solver, n_samples=3, sigma=1.0, processes=2)

    ## Initialize arrays that will be returned and checkpoint directory
    val_loss_hist = []
    val_acc_hist = []
    epoch_time_hist = []
    checkpt_path = weights_dir + model_type + '/'
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(checkpt_path, exist_ok=True)

    net.eval()
    net.to('cpu')
    best_val_loss = metric(net, net.shortest_path_solver, loader_val)

    print('Initial validation loss is ', best_val_loss)
    val_loss_hist.append(best_val_loss)
    time_till_best_val_loss = 0

    curr_val_acc = accuracy(net, net.shortest_path_solver, loader_val)
    val_acc_hist.append(curr_val_acc)

    ## Compute initial test loss
    best_test_loss = metric(net, net.shortest_path_solver, loader_test)
    print('Initial test loss is ', best_test_loss)

    ## Train!
    epoch = 1
    train_time = 0
    train_loss_ave = 0

    while epoch <= max_epochs and train_time <= max_time:
        start_time_epoch = time.time()
        net.to(device)
        # nothing inside the batch loop leaves training mode, so set it once
        net.train()
        # Iterate the training batch; only inputs and optimal solutions are used
        for d_batch, _w_batch, opt_sol, _opt_value in loader_train:
            d_batch = d_batch.to(device)
            opt_sol = opt_sol.to(device)
            optimizer.zero_grad()
            predicted = net(d_batch)
            if model_type == "DYS" or model_type == "CVX":
                loss = criterion(opt_sol, predicted)
            elif model_type == "BBOpt":
                x_predicted = dbb(predicted)
                loss = criterion(opt_sol, x_predicted)
            elif model_type == "PertOpt":
                x_predicted = ptb(predicted)
                loss = criterion(opt_sol, x_predicted)

            loss.backward()
            optimizer.step()
            # exponential moving average of the training loss
            train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()

        end_time_epoch = time.time()
        epoch_time = end_time_epoch - start_time_epoch
        train_time += epoch_time
        epoch_time_hist.append(epoch_time)

        ## Now compute loss on validation set
        net.eval()
        net.to('cpu')
        val_loss = metric(net, net.shortest_path_solver, loader_val)
        val_acc = accuracy(net, net.shortest_path_solver, loader_val)
        val_acc_hist.append(val_acc)
        print('\n Current validation accuracy is ' + str(val_acc))
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            state_save_name = checkpt_path+'best.pth'
            torch.save(net.state_dict(), state_save_name)
            # If we have achieved lowest validation thus far, this will be the model selected.
            # So, we compute test loss
            best_test_loss = metric(net, net.shortest_path_solver, loader_test)
            best_val_loss = val_loss
            print('Best validation regret achieved at epoch ' + str(epoch))
            time_till_best_val_loss = sum(epoch_time_hist)

        val_loss_hist.append(val_loss)

        print('epoch: ', epoch, 'validation regret is ', val_loss, 'epoch time: ', epoch_time)
        epoch += 1

    state_save_name = checkpt_path+'last.pth'
    torch.save(net.state_dict(), state_save_name)
    # validation never improved: report the total training time instead
    if time_till_best_val_loss < 1e-6:
        time_till_best_val_loss = sum(epoch_time_hist)

    # Collect results
    results = {"val_loss_hist": val_loss_hist,
               "val_acc_hist": val_acc_hist,
               "epoch_time_hist": epoch_time_hist,
               "best_test_loss": best_test_loss,
               "time_till_best_val_loss": time_till_best_val_loss
               }

    return results

0 comments on commit 3d13260

Please sign in to comment.