diff --git a/src/traveling_salesman/generate_tsp_data.py b/src/traveling_salesman/generate_tsp_data.py new file mode 100644 index 0000000..f61f390 --- /dev/null +++ b/src/traveling_salesman/generate_tsp_data.py @@ -0,0 +1,75 @@ +from sklearn.model_selection import train_test_split +import pyepo +import torch +import dill +import argparse +import os + +# generate actual data +def main(args): + + # unpack args + num_nodes = args.num_nodes + num_data = args.num_data + num_feat = args.num_feat + num_item = args.num_item + + print('Generating training data for traveling salesman problem with '+str(num_nodes) + + ' nodes and ' + str(num_item) + ' items') + + weights_numpy, contexts_numpy, costs_numpy = pyepo.data.tsp.genData(num_data, num_feat, num_nodes, + deg=args.data_deg, noise_width=args.data_noise_width) + + # split train test data + d_train, d_test_val, w_train, w_test_val = train_test_split(contexts_numpy, costs_numpy, test_size=200) + d_test, d_val, w_test, w_val = train_test_split(d_test_val, w_test_val, test_size=100) + + # Define PyEPO model + caps = capacities.cpu() #[20] * 2 # capacity + optmodel = pyepo.model.grb.knapsackModel(weights_numpy, caps) + + # get optDataset + dataset_train = pyepo.data.dataset.optDataset(optmodel, d_train, w_train) + dataset_test = pyepo.data.dataset.optDataset(optmodel, d_test, w_test) + dataset_val = pyepo.data.dataset.optDataset(optmodel, d_val, w_val) + + # Remove the gurobi model befor esaving + dataset_train.model = None + dataset_test.model = None + dataset_val.model = None + + # Package into a dictionary + state = { 'weights_numpy' : weights_numpy, + 'contexts_numpy': contexts_numpy, + 'costs_numpy' : costs_numpy, + 'dataset_train' : dataset_train, + 'dataset_test' : dataset_test, + 'dataset_val' : dataset_val, + 'capacities' : capacities, + } + + # Save and finish up + print('Finished building dataset') + if not os.path.exists(args.data_dir): + os.makedirs(args.data_dir) + + state_path = args.data_dir + 'Knapsack_training_data_' + str(num_knapsack) + '_' + str(num_item) +'_data-deg_' + str(args.data_deg) +'.p' + dill.dump( state, open( state_path, "wb" ) ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='generate knapsack data') + parser.add_argument('--num_data', type=int, default=1000) + parser.add_argument('--num_feat', type=int, default=5) + parser.add_argument('--num_item', type=int, default=20) + parser.add_argument('-num_knapsack', type=int, default=2) + parser.add_argument('--data_deg', type=int, default=4) + parser.add_argument('--data_noise_width', type=float, default=0.5) + parser.add_argument('--data_dir', type=str, default='./src/knapsack/knapsack_data/') + args = parser.parse_args() + main(args) + + + + + +