Skip to content

Commit

Permalink
Started working on traveling_salesman
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielMckenzie committed Jun 14, 2024
1 parent 7a2fb70 commit 4250875
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/traveling_salesman/generate_tsp_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from sklearn.model_selection import train_test_split
import pyepo
import torch
import dill
import argparse
import os

# generate actual data
def main(args):

# unpack args
num_nodes = args.num_nodes
num_data = args.num_data
num_feat = args.num_feat
num_item = args.num_item

print('Generating training data for traveling salesman problem with '+str(num_nodes)
+ ' nodes and ' + str(num_item) + ' items')

weights_numpy, contexts_numpy, costs_numpy = pyepo.data.tsp.genData(num_data, num_feat, num_nodes,
deg=args.data_deg, noise_width=args.data_noise_width)

# split train test data
d_train, d_test_val, w_train, w_test_val = train_test_split(contexts_numpy, costs_numpy, test_size=200)
d_test, d_val, w_test, w_val = train_test_split(d_test_val, w_test_val, test_size=100)

# Define PyEPO model
caps = capacities.cpu() #[20] * 2 # capacity
optmodel = pyepo.model.grb.knapsackModel(weights_numpy, caps)

# get optDataset
dataset_train = pyepo.data.dataset.optDataset(optmodel, d_train, w_train)
dataset_test = pyepo.data.dataset.optDataset(optmodel, d_test, w_test)
dataset_val = pyepo.data.dataset.optDataset(optmodel, d_val, w_val)

# Remove the gurobi model befor esaving
dataset_train.model = None
dataset_test.model = None
dataset_val.model = None

# Package into a dictionary
state = { 'weights_numpy' : weights_numpy,
'contexts_numpy': contexts_numpy,
'costs_numpy' : costs_numpy,
'dataset_train' : dataset_train,
'dataset_test' : dataset_test,
'dataset_val' : dataset_val,
'capacities' : capacities,
}

# Save and finish up
print('Finished building dataset')
if not os.path.exists(args.data_dir):
os.makedirs(args.data_dir)

state_path = args.data_dir + 'Knapsack_training_data_' + str(num_knapsack) + '_' + str(num_item) +'_data-deg_' + str(args.data_deg) +'.p'
dill.dump( state, open( state_path, "wb" ) )

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='generate knapsack data')
parser.add_argument('--num_data', type=int, default=1000)
parser.add_argument('--num_feat', type=int, default=5)
parser.add_argument('--num_item', type=int, default=20)
parser.add_argument('-num_knapsack', type=int, default=2)
parser.add_argument('--data_deg', type=int, default=4)
parser.add_argument('--data_noise_width', type=float, default=0.5)
parser.add_argument('--data_dir', type=str, default='./src/knapsack/knapsack_data/')
args = parser.parse_args()
main(args)






0 comments on commit 4250875

Please sign in to comment.