diff --git a/main_driver_warcraft.py b/main_driver_warcraft.py index b965fc7..65b72eb 100644 --- a/main_driver_warcraft.py +++ b/main_driver_warcraft.py @@ -1,7 +1,7 @@ # Assume path is root directory from src.models import ShortestPathNet, Cvx_ShortestPathNet, Pert_ShortestPathNet, BB_ShortestPathNet -from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net +from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net, BB_Warcraft_Net import matplotlib.pyplot as plt import time as time from src.trainer import trainer_warcraft @@ -15,7 +15,7 @@ ## Some fixed hyperparameters max_epochs = 100 -init_lr = 1e-5 # initial learning rate. We're using a scheduler. +init_lr = 1e-6 # initial learning rate. We're using a scheduler. torch.manual_seed(0) # check that directory to save data exists @@ -118,8 +118,64 @@ # ## Save Histories # torch.save(state, './src/warcraft/results/'+'DYS_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') +# # --------------------------------------------------------------- +# # ------------------------ Train PertOpt ------------------------ +# # --------------------------------------------------------------- + +# ## Load data +# data_path = base_data_path + 'Warcraft_training_data'+str(grid_size)+'.pth' +# state = torch.load(data_path) + +# ## Extract data from state +# train_dataset = state['train_dataset'] +# val_dataset = state['val_dataset'] +# test_dataset = state['test_dataset'] + + +# m= state["m"] +# # A = state["A"].float() +# # b = state["b"].float() +# num_edges = state["num_edges"] +# edge_list = state["edge_list"] +# edge_list_torch = torch.tensor(edge_list) + +# # A = A.to(device) +# # b = b.to(device) + + +# PertOpt_net = Pert_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device='cpu') +# PertOpt_net.to('cpu') + +# # # Train +# print('\n-------------------------------------------- TRAINING PertOpt Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------') + +# start_time = time.time() +# best_params_PertOpt, val_loss_hist_PertOpt, val_acc_hist_PertOpt, test_loss_PertOpt, test_acc_PertOpt, train_time_PertOpt = trainer_warcraft(PertOpt_net, train_dataset, val_dataset, test_dataset, +# grid_size, max_epochs, init_lr, edge_list, +# use_scheduler=False, device='cpu', +# train_batch_size=train_batch_size, +# test_batch_size=test_batch_size, +# graph_type='V') +# end_time = time.time() +# print('\n time to train PertOpt GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds') + +# state = { +# 'val_loss_hist_PertOpt': val_loss_hist_PertOpt, +# 'val_acc_hist_PertOpt': val_acc_hist_PertOpt, +# 'test_loss_PertOpt': test_loss_PertOpt, +# 'test_acc_PertOpt': test_acc_PertOpt, +# 'train_time_PertOpt': train_time_PertOpt +# } + +# # Save weights +# torch.save(best_params_PertOpt, './src/warcraft/saved_weights/'+'PertOpt_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') + +# ## Save Histories +# torch.save(state, './src/warcraft/results/'+'PertOpt_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') + + # --------------------------------------------------------------- -# ------------------------ Train PertOpt ------------------------ +# ------------------------ Train BB ------------------------ # --------------------------------------------------------------- ## Load data @@ -143,32 +199,32 @@ # b = b.to(device) -PertOpt_net = Pert_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device='cpu') -PertOpt_net.to('cpu') +BB_net = BB_Warcraft_Net(edges=edge_list, num_edges=num_edges, m=m, device=device) +BB_net = BB_net.to(device) # # Train -print('\n-------------------------------------------- TRAINING PertOpt Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------') +print('\n-------------------------------------------- TRAINING Blackbox Backprop Warcraft GRID ' + str(grid_size) + '-by-' + str(grid_size) + ' --------------------------------------------') start_time = time.time() -best_params_PertOpt, val_loss_hist_PertOpt, val_acc_hist_PertOpt, test_loss_PertOpt, test_acc_PertOpt, train_time_PertOpt = trainer_warcraft(PertOpt_net, train_dataset, val_dataset, test_dataset, +best_params_BB, val_loss_hist_BB, val_acc_hist_BB, test_loss_BB, test_acc_BB, train_time_BB = trainer_warcraft(BB_net, train_dataset, val_dataset, test_dataset, grid_size, max_epochs, init_lr, edge_list, - use_scheduler=False, device='cpu', + use_scheduler=False, device=device, train_batch_size=train_batch_size, test_batch_size=test_batch_size, graph_type='V') end_time = time.time() -print('\n time to train PertOpt GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds') +print('\n time to train BB GRID ' + str(grid_size) + '-by-' + str(grid_size), ' = ', end_time-start_time, ' seconds') state = { - 'val_loss_hist_PertOpt': val_loss_hist_PertOpt, - 'val_acc_hist_PertOpt': val_acc_hist_PertOpt, - 'test_loss_PertOpt': test_loss_PertOpt, - 'test_acc_PertOpt': test_acc_PertOpt, - 'train_time_PertOpt': train_time_PertOpt + 'val_loss_hist_BB': val_loss_hist_BB, + 'val_acc_hist_BB': val_acc_hist_BB, + 'test_loss_BB': test_loss_BB, + 'test_acc_BB': test_acc_BB, + 'train_time_BB': train_time_BB } # Save weights -torch.save(best_params_PertOpt, './src/warcraft/saved_weights/'+'PertOpt_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') +torch.save(best_params_BB, './src/warcraft/saved_weights/'+'BB_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') ## Save Histories -torch.save(state, './src/warcraft/results/'+'PertOpt_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') \ No newline at end of file +torch.save(state, './src/warcraft/results/'+'BB_results_'+str(grid_size) + '-by-' + str(grid_size) + '.pth') \ No newline at end of file diff --git a/src/models.py b/src/models.py index 1a5fc22..7674f97 100644 --- a/src/models.py +++ b/src/models.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import cvxpy as cp -# import blackbox_backprop as bb +import blackbox_backprop as bb from cvxpylayers.torch import CvxpyLayer from abc import ABC, abstractmethod from src.dys_opt_net import DYS_opt_net @@ -241,5 +241,35 @@ def forward(self, d): else: path = self.dijkstra(cost_vec.view(cost_vec.shape[0], self.m, self.m), batch_mode=True) return path.to(self.device) + +## Create NN using Blackbox backprop of Vlastelica et al +class BB_Warcraft_Net(nn.Module): + ''' + This net is equipped to run an m-by-m grid graphs. No A matrix is necessary. + Not quite working. No signal is backpropagating? + ''' + def __init__(self, edges, num_edges, m, device='cpu', in_channels=3): + super().__init__() + self.m = m + self.device = device + self.shortestPath = bb.ShortestPath() + + self.resnet_model = torchvision.models.resnet18(pretrained=False, num_classes=num_edges) + del self.resnet_model.conv1 + self.resnet_model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.fc_final = nn.Linear(in_features=64*24*24, out_features=self.m**2) + + def forward(self, d): + batch_size = d.shape[0] + d = self.resnet_model.conv1(d) + d = self.resnet_model.bn1(d) + d = self.resnet_model.relu(d) + d = self.resnet_model.maxpool(d) + d = self.resnet_model.layer1(d) + cost_vec = self.fc_final(d.reshape(batch_size, -1)).view(batch_size, -1) # size = batch_size x num_vertices + suggested_weights = cost_vec.view(cost_vec.shape[0], self.m, self.m) + suggested_shortest_paths = self.shortestPath.apply(suggested_weights, 100) + + return suggested_shortest_paths diff --git a/src/trainer.py b/src/trainer.py index 32e33d9..d10a225 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -247,14 +247,19 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset, while epoch <= max_epochs: # training step - for terrain_batch, path_batch_edge, _, _ in train_loader: + for terrain_batch, path_batch_edge, path_batch_vertex, _ in train_loader: terrain_batch = terrain_batch.to(device) - path_batch_edge =path_batch_edge.to(device) + path_batch_edge = path_batch_edge.to(device) + path_batch_vertex = path_batch_vertex.to(device) net.train() optimizer.zero_grad() path_pred = net(terrain_batch) - loss = criterion(path_pred, path_batch_edge) + + if graph_type=='E': + loss = criterion(path_pred, path_batch_edge) + else: + loss = criterion(path_pred, path_batch_vertex) train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item() loss.backward() optimizer.step() diff --git a/tests/test_utilities.py b/tests/test_utilities.py index d9a6f5d..3f7a432 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import time as time import unittest +import blackbox_backprop as bb import sys root_dir = "../" @@ -14,7 +15,7 @@ from src.models import ShortestPathNet, Cvx_ShortestPathNet, Pert_ShortestPathNet, BB_ShortestPathNet -from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net +from src.models import DYS_Warcraft_Net, Pert_Warcraft_Net, BB_Warcraft_Net from src.trainer import trainer from src.utils import edge_to_node, node_to_edge, compute_accuracy @@ -47,10 +48,13 @@ def setUp(self): self.A = A.to(self.device) self.b = b.to(self.device) self.n_samples = 10 - self.d_edge, self.path_edge, self.path_vertex, self.costs = test_dataset[0:self.n_samples] + self.terrain, self.path_edge, self.path_vertex, self.costs = test_dataset[0:self.n_samples] self.dys_net = DYS_Warcraft_Net(self.A, self.b, self.edge_list, self.num_edges, self.device) self.dys_net.to(self.device) + + self.bb_net = BB_Warcraft_Net(self.edge_list, self.num_edges, self.grid_size, device=self.device) + self.bb_net.to(self.device) def test_edge_to_node(self): # Test that edge_to_node returns correct node path @@ -129,18 +133,25 @@ def test_compute_accuracy(self): def test_dys_net(self): - path_pred = self.dys_net(self.d_edge).detach() - cost_pred = self.dys_net.data_space_forward(self.d_edge).detach() + path_pred = self.dys_net(self.terrain).detach() + cost_pred = self.dys_net.data_space_forward(self.terrain).detach() self.assertTrue(torch.all(cost_pred >= 0)) print('cost_vec >= 0') for i in range(path_pred.shape[0]): constraint_norm = torch.norm(self.A@path_pred[i,:] - self.b) - self.assertTrue( constraint_norm <= 1e-2) - print('for sample ', i, ', |Ax - b| = ', constraint_norm, ' < 1e-2') + print('for sample ', i, ', |Ax - b| = ', constraint_norm, ' < 1e-1') + self.assertTrue( constraint_norm <= 1e-1) print('\n\n-------------- dys_net tests passed --------------\n\n') + def test_bb_net(self): + + path_pred = self.bb_net(self.terrain) + self.assertTrue(path_pred.shape == self.path_vertex.shape) + + print('\n\n-------------- bb_net tests passed --------------\n\n') + if __name__ == '__main__': unittest.main() \ No newline at end of file