40
40
from uuid import uuid4 as uuid
41
41
42
42
import numpy as np
43
+ import torch
43
44
44
45
import ray
45
46
from ray import tune
@@ -805,7 +806,7 @@ def parse_arguments():
805
806
help = "Perturbation interval for PopulationBasedTraining." ,
806
807
)
807
808
tune_parser .add_argument (
808
- "--seed" , type = int , metavar = "<int>" , default = 42 , help = "Random seed."
809
+ "--seed" , type = int , metavar = "<int>" , default = 42 , help = "Random seed. (0 means no seed.) "
809
810
)
810
811
811
812
# Workload
@@ -870,10 +871,23 @@ def set_algorithm(experiment_name, config):
870
871
"""
871
872
Configure search algorithm.
872
873
"""
874
+ # Pre-set seed if user sets seed to 0
875
+ if args .seed == 0 :
876
+ print ("Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)" )
877
+ if input ().lower () != "y" :
878
+ sys .exit (0 )
879
+ args .seed = None
880
+
873
881
if args .algorithm == "hyperopt" :
874
- algorithm = HyperOptSearch (points_to_evaluate = best_params )
882
+ algorithm = HyperOptSearch (
883
+ points_to_evaluate = best_params ,
884
+ random_state_seed = args .seed ,
885
+ )
875
886
elif args .algorithm == "ax" :
876
- ax_client = AxClient (enforce_sequential_optimization = False )
887
+ ax_client = AxClient (
888
+ enforce_sequential_optimization = False ,
889
+ random_seed = args .seed ,
890
+ )
877
891
AxClientMetric = namedtuple ("AxClientMetric" , "minimize" )
878
892
ax_client .create_experiment (
879
893
name = experiment_name ,
@@ -882,18 +896,33 @@ def set_algorithm(experiment_name, config):
882
896
)
883
897
algorithm = AxSearch (ax_client = ax_client , points_to_evaluate = best_params )
884
898
elif args .algorithm == "optuna" :
885
- algorithm = OptunaSearch (points_to_evaluate = best_params , seed = args .seed )
899
+ algorithm = OptunaSearch (
900
+ points_to_evaluate = best_params ,
901
+ seed = args .seed
902
+ )
886
903
elif args .algorithm == "pbt" :
904
+ print ("Warning: PBT does not support seed values. args.seed will be ignored." )
887
905
algorithm = PopulationBasedTraining (
888
906
time_attr = "training_iteration" ,
889
907
perturbation_interval = args .perturbation ,
890
908
hyperparam_mutations = config ,
891
909
synch = True ,
892
910
)
893
911
elif args .algorithm == "random" :
894
- algorithm = BasicVariantGenerator (max_concurrent = args .jobs )
912
+ algorithm = BasicVariantGenerator (
913
+ max_concurrent = args .jobs ,
914
+ random_state = args .seed ,)
915
+
916
+ # A wrapper algorithm for limiting the number of concurrent trials.
895
917
if args .algorithm not in ["random" , "pbt" ]:
896
918
algorithm = ConcurrencyLimiter (algorithm , max_concurrent = args .jobs )
919
+
920
+ # Self seed
921
+ if args .seed is not None :
922
+ torch .manual_seed (args .seed )
923
+ np .random .seed (args .seed )
924
+ random .seed (args .seed )
925
+
897
926
return algorithm
898
927
899
928
0 commit comments