Skip to content

Commit

Permalink
fixed some args fields in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielMckenzie committed Jan 16, 2024
1 parent 05fc386 commit 97886dd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions shortest_path_experiment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Define the model types
model_types="DYS CVX PertOpt BBOpt"
grid_sizes="10 20 30 40 50"
grid_sizes="5 10 15 20 25 30"
reps="1 2 3"
data_dir="./src/shortest_path/shortest_path_data"
weights_dir="./src/shortest_path/saved_weights"
Expand All @@ -18,7 +18,7 @@ do
rep_data_dir="${data_dir}/$rep"
rep_weights_dir="${weights_dir}/$rep"
rep_results_dir="${results_dir}/$rep"
python -m src.shortest_path.train --model_type $model_type --grid_size $grid_size --num_data 1000 --data_dir $rep_data_dir --results_dir $rep_results_dir
python -m src.shortest_path.train --model_type $model_type --grid_size $grid_size --num_data 1000 --data_dir $rep_data_dir --results_dir $rep_results_dir --device cuda:0
echo "$grid_size"
done
done
Expand Down
2 changes: 2 additions & 0 deletions src/shortest_path/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def main(args):
parser = argparse.ArgumentParser(description="fpo-dys shortest path")
parser.add_argument('--grid_size', type=int, default=5)
parser.add_argument('--data_dir', type=str, default='./src/shortest_path/shortest_path_data/')
parser.add_argument('--data_deg', type=int, default=4)
parser.add_argument('--data_noise_width', type=float, default=0.5)
parser.add_argument('--weights_dir', type=str, default='./src/shortest_path/saved_weights/')
parser.add_argument('--results_dir', type=str, default='./src/shortest_path/results/')
parser.add_argument('--device', type=str, default='mps')
Expand Down

0 comments on commit 97886dd

Please sign in to comment.