Commit

create BB-Net driver

Samy Wu Fung committed Aug 11, 2023
1 parent 19476ab commit e91042a
Showing 4 changed files with 128 additions and 26 deletions.
88 changes: 72 additions & 16 deletions main_driver_warcraft.py
@@ -1,7 +1,7 @@
# Assume path is root directory

from src.models import ShortestPathNet, Cvx_ShortestPathNet, Pert_ShortestPathNet, BB_ShortestPathNet
from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net
from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net, BB_Warcraft_Net
import matplotlib.pyplot as plt
import time as time
from src.trainer import trainer_warcraft
@@ -15,7 +15,7 @@

## Some fixed hyperparameters
max_epochs = 100
init_lr = 1e-5 # initial learning rate. We're using a scheduler.
init_lr = 1e-6 # initial learning rate. We're using a scheduler.
torch.manual_seed(0)

# check that directory to save data exists
@@ -118,8 +118,64 @@
# ## Save Histories
# torch.save(state, './src/warcraft/results/'+'DYS_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')

# # ---------------------------------------------------------------
# # ------------------------ Train PertOpt ------------------------
# # ---------------------------------------------------------------

# ## Load data
# data_path = base_data_path + 'Warcraft_training_data'+str(grid_size)+'.pth'
# state = torch.load(data_path)

# ## Extract data from state
# train_dataset = state['train_dataset']
# val_dataset = state['val_dataset']
# test_dataset = state['test_dataset']


# m= state["m"]
# # A = state["A"].float()
# # b = state["b"].float()
# num_edges = state["num_edges"]
# edge_list = state["edge_list"]
# edge_list_torch = torch.tensor(edge_list)

# # A = A.to(device)
# # b = b.to(device)


# PertOpt_net = Pert_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device='cpu')
# PertOpt_net.to('cpu')

# # # Train
# print('\n-------------------------------------------- TRAINING PertOpt Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------')

# start_time = time.time()
# best_params_PertOpt, val_loss_hist_PertOpt, val_acc_hist_PertOpt, test_loss_PertOpt, test_acc_PertOpt, train_time_PertOpt = trainer_warcraft(PertOpt_net, train_dataset, val_dataset, test_dataset,
# grid_size, max_epochs, init_lr, edge_list,
# use_scheduler=False, device='cpu',
# train_batch_size=train_batch_size,
# test_batch_size=test_batch_size,
# graph_type='V')
# end_time = time.time()
# print('\n time to train PertOpt GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds')

# state = {
# 'val_loss_hist_PertOpt': val_loss_hist_PertOpt,
# 'val_acc_hist_PertOpt': val_acc_hist_PertOpt,
# 'test_loss_PertOpt': test_loss_PertOpt,
# 'test_acc_PertOpt': test_acc_PertOpt,
# 'train_time_PertOpt': train_time_PertOpt
# }

# # Save weights
# torch.save(best_params_PertOpt, './src/warcraft/saved_weights/'+'PertOpt_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')

# ## Save Histories
# torch.save(state, './src/warcraft/results/'+'PertOpt_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')


# ---------------------------------------------------------------
# ------------------------ Train PertOpt ------------------------
# ------------------------ Train BB ------------------------
# ---------------------------------------------------------------

## Load data
@@ -143,32 +199,32 @@
# b = b.to(device)


PertOpt_net = Pert_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device='cpu')
PertOpt_net.to('cpu')
BB_net = BB_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device=device)
BB_net = BB_net.to(device)

# # Train
print('\n-------------------------------------------- TRAINING PertOpt Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------')
print('\n-------------------------------------------- TRAINING Blackbox Backprop Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------')

start_time = time.time()
best_params_PertOpt, val_loss_hist_PertOpt, val_acc_hist_PertOpt, test_loss_PertOpt, test_acc_PertOpt, train_time_PertOpt = trainer_warcraft(PertOpt_net, train_dataset, val_dataset, test_dataset,
best_params_BB, val_loss_hist_BB, val_acc_hist_BB, test_loss_BB, test_acc_BB, train_time_BB = trainer_warcraft(BB_net, train_dataset, val_dataset, test_dataset,
grid_size, max_epochs, init_lr, edge_list,
use_scheduler=False, device='cpu',
use_scheduler=False, device=device,
train_batch_size=train_batch_size,
test_batch_size=test_batch_size,
graph_type='V')
end_time = time.time()
print('\n time to train PertOpt GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds')
print('\n time to train BB GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds')

state = {
'val_loss_hist_PertOpt': val_loss_hist_PertOpt,
'val_acc_hist_PertOpt': val_acc_hist_PertOpt,
'test_loss_PertOpt': test_loss_PertOpt,
'test_acc_PertOpt': test_acc_PertOpt,
'train_time_PertOpt': train_time_PertOpt
'val_loss_hist_BB': val_loss_hist_BB,
'val_acc_hist_BB': val_acc_hist_BB,
'test_loss_BB': test_loss_BB,
'test_acc_BB': test_acc_BB,
'train_time_BB': train_time_BB
}

# Save weights
torch.save(best_params_PertOpt, './src/warcraft/saved_weights/'+'PertOpt_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')
torch.save(best_params_BB, './src/warcraft/saved_weights/'+'BB_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')

## Save Histories
torch.save(state, './src/warcraft/results/'+'PertOpt_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')
torch.save(state, './src/warcraft/results/'+'BB_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth')
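For reference, a minimal sketch of how the saved BB histories could be reloaded for inspection afterwards. The file path mirrors the torch.save call above; grid_size = 12 is an assumption here, since the actual value is set earlier in the driver, outside this hunk.

import torch
import matplotlib.pyplot as plt

grid_size = 12  # assumed value; the driver sets grid_size elsewhere
results = torch.load('./src/warcraft/results/BB_results_' + str(grid_size) + '-by-' + str(grid_size) + '.pth')
print('BB test accuracy:', results['test_acc_BB'], '| train time (s):', results['train_time_BB'])
plt.plot(results['val_loss_hist_BB'])  # validation loss history logged by trainer_warcraft
plt.show()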
32 changes: 31 additions & 1 deletion src/models.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import cvxpy as cp
# import blackbox_backprop as bb
import blackbox_backprop as bb
from cvxpylayers.torch import CvxpyLayer
from abc import ABC, abstractmethod
from src.dys_opt_net import DYS_opt_net
@@ -241,5 +241,35 @@ def forward(self, d):
        else:
            path = self.dijkstra(cost_vec.view(cost_vec.shape[0], self.m, self.m), batch_mode=True)
        return path.to(self.device)

## Create NN using Blackbox backprop of Vlastelica et al.
class BB_Warcraft_Net(nn.Module):
    '''
    This net is equipped to run on m-by-m grid graphs. No A matrix is necessary.
    Not quite working. No signal is backpropagating?
    '''
    def __init__(self, edges, num_edges, m, device='cpu', in_channels=3):
        super().__init__()
        self.m = m
        self.device = device
        self.shortestPath = bb.ShortestPath()

        self.resnet_model = torchvision.models.resnet18(pretrained=False, num_classes=num_edges)
        del self.resnet_model.conv1
        self.resnet_model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.fc_final = nn.Linear(in_features=64*24*24, out_features=self.m**2)

    def forward(self, d):
        batch_size = d.shape[0]
        d = self.resnet_model.conv1(d)
        d = self.resnet_model.bn1(d)
        d = self.resnet_model.relu(d)
        d = self.resnet_model.maxpool(d)
        d = self.resnet_model.layer1(d)
        cost_vec = self.fc_final(d.reshape(batch_size, -1)).view(batch_size, -1) # size = batch_size x num_vertices
        suggested_weights = cost_vec.view(cost_vec.shape[0], self.m, self.m)
        suggested_shortest_paths = self.shortestPath.apply(suggested_weights, 100)

        return suggested_shortest_paths
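As a sanity check on the new class, a minimal smoke-test sketch follows. It assumes torchvision is imported elsewhere in src/models.py (the import is not visible in this hunk), and it assumes the 12-by-12 Warcraft setup with 96x96 terrain images, which is what makes the hard-coded 64*24*24 feature size of fc_final line up. The second argument to ShortestPath.apply (100) appears to play the role of the lambda hyperparameter of Vlastelica et al., but that is an assumption.

import torch
from src.models import BB_Warcraft_Net

# In the hunk above, edges is unused and num_edges only sizes the resnet fc head
# that forward() bypasses, so placeholder values suffice for a shape check.
net = BB_Warcraft_Net(edges=[], num_edges=1, m=12, device='cpu')
terrain = torch.rand(2, 3, 96, 96)   # two fake RGB terrain images
paths = net(terrain)
print(paths.shape)                   # expected: (2, 12, 12), one vertex-grid path per sample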


11 changes: 8 additions & 3 deletions src/trainer.py
@@ -247,14 +247,19 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset,
    while epoch <= max_epochs:

        # training step
        for terrain_batch, path_batch_edge, _, _ in train_loader:
        for terrain_batch, path_batch_edge, path_batch_vertex, _ in train_loader:

            terrain_batch = terrain_batch.to(device)
            path_batch_edge =path_batch_edge.to(device)
            path_batch_edge = path_batch_edge.to(device)
            path_batch_vertex = path_batch_vertex.to(device)
            net.train()
            optimizer.zero_grad()
            path_pred = net(terrain_batch)
            loss = criterion(path_pred, path_batch_edge)

            if graph_type=='E':
                loss = criterion(path_pred, path_batch_edge)
            else:
                loss = criterion(path_pred, path_batch_vertex)
            train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()
            loss.backward()
            optimizer.step()
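The new branch above selects the training target to match the model's output representation: graph_type='E' nets are scored against the edge-indicator path, while graph_type='V' nets, such as BB_Warcraft_Net in the driver, return m-by-m vertex-grid paths and are scored against the vertex target. A hypothetical helper restating that logic (not part of the repo):

def path_loss(criterion, path_pred, path_batch_edge, path_batch_vertex, graph_type):
    # 'E' models predict edge-indicator paths; 'V' models return vertex-grid paths,
    # so the target must use the same representation as the prediction.
    target = path_batch_edge if graph_type == 'E' else path_batch_vertex
    return criterion(path_pred, target)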
23 changes: 17 additions & 6 deletions tests/test_utilities.py
@@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
import time as time
import unittest
import blackbox_backprop as bb

import sys
root_dir = "../"
@@ -14,7 +15,7 @@


from src.models import ShortestPathNet, Cvx_ShortestPathNet, Pert_ShortestPathNet, BB_ShortestPathNet
from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net
from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net, BB_Warcraft_Net
from src.trainer import trainer
from src.utils import edge_to_node, node_to_edge, compute_accuracy

@@ -47,10 +48,13 @@ def setUp(self):
        self.A = A.to(self.device)
        self.b = b.to(self.device)
        self.n_samples = 10
        self.d_edge, self.path_edge, self.path_vertex, self.costs = test_dataset[0:self.n_samples]
        self.terrain, self.path_edge, self.path_vertex, self.costs = test_dataset[0:self.n_samples]

        self.dys_net = DYS_Warcraft_Net(self.A, self.b, self.edge_list, self.num_edges, self.device)
        self.dys_net.to(self.device)

        self.bb_net = BB_Warcraft_Net(self.edge_list, self.num_edges, self.grid_size, device=self.device)
        self.bb_net.to(self.device)

    def test_edge_to_node(self):
        # Test that edge_to_node returns correct node path
@@ -129,18 +133,25 @@ def test_compute_accuracy(self):

    def test_dys_net(self):

        path_pred = self.dys_net(self.d_edge).detach()
        cost_pred = self.dys_net.data_space_forward(self.d_edge).detach()
        path_pred = self.dys_net(self.terrain).detach()
        cost_pred = self.dys_net.data_space_forward(self.terrain).detach()

        self.assertTrue(torch.all(cost_pred >= 0))
        print('cost_vec >= 0')

        for i in range(path_pred.shape[0]):
            constraint_norm = torch.norm(self.A@path_pred[i,:] - self.b)
            self.assertTrue( constraint_norm <= 1e-2)
            print('for sample ', i, ', |Ax - b| = ', constraint_norm, ' < 1e-2')
            print('for sample ', i, ', |Ax - b| = ', constraint_norm, ' < 1e-1')
            self.assertTrue( constraint_norm <= 1e-1)

        print('\n\n-------------- dys_net tests passed --------------\n\n')

    def test_bb_net(self):

        path_pred = self.bb_net(self.terrain)
        self.assertTrue(path_pred.shape == self.path_vertex.shape)

        print('\n\n-------------- bb_net tests passed --------------\n\n')

if __name__ == '__main__':
    unittest.main()
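The new test_bb_net only checks that the BB net returns one path per sample with the same shape as the vertex-grid ground truth, which is consistent with training it under graph_type='V' in the driver. Once BB_Warcraft_Net is importable, an invocation along the lines of python -m unittest -k test_bb_net should exercise it in isolation (the exact command depends on how this suite is normally launched, so treat it as an assumption).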
