@@ -806,7 +806,11 @@ def parse_arguments():
806
806
help = "Perturbation interval for PopulationBasedTraining." ,
807
807
)
808
808
tune_parser .add_argument (
809
- "--seed" , type = int , metavar = "<int>" , default = 42 , help = "Random seed. (0 means no seed.)"
809
+ "--seed" ,
810
+ type = int ,
811
+ metavar = "<int>" ,
812
+ default = 42 ,
813
+ help = "Random seed. (0 means no seed.)" ,
810
814
)
811
815
812
816
# Workload
@@ -873,10 +877,16 @@ def set_algorithm(experiment_name, config):
873
877
"""
874
878
# Pre-set seed if user sets seed to 0
875
879
if args .seed == 0 :
876
- print ("Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)" )
880
+ print (
881
+ "Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
882
+ )
877
883
if input ().lower () != "y" :
878
884
sys .exit (0 )
879
885
args .seed = None
886
+ else :
887
+ torch .manual_seed (args .seed )
888
+ np .random .seed (args .seed )
889
+ random .seed (args .seed )
880
890
881
891
if args .algorithm == "hyperopt" :
882
892
algorithm = HyperOptSearch (
@@ -896,10 +906,7 @@ def set_algorithm(experiment_name, config):
896
906
)
897
907
algorithm = AxSearch (ax_client = ax_client , points_to_evaluate = best_params )
898
908
elif args .algorithm == "optuna" :
899
- algorithm = OptunaSearch (
900
- points_to_evaluate = best_params ,
901
- seed = args .seed
902
- )
909
+ algorithm = OptunaSearch (points_to_evaluate = best_params , seed = args .seed )
903
910
elif args .algorithm == "pbt" :
904
911
print ("Warning: PBT does not support seed values. args.seed will be ignored." )
905
912
algorithm = PopulationBasedTraining (
@@ -911,18 +918,13 @@ def set_algorithm(experiment_name, config):
911
918
elif args .algorithm == "random" :
912
919
algorithm = BasicVariantGenerator (
913
920
max_concurrent = args .jobs ,
914
- random_state = args .seed ,)
921
+ random_state = args .seed ,
922
+ )
915
923
916
924
# A wrapper algorithm for limiting the number of concurrent trials.
917
925
if args .algorithm not in ["random" , "pbt" ]:
918
926
algorithm = ConcurrencyLimiter (algorithm , max_concurrent = args .jobs )
919
927
920
- # Self seed
921
- if args .seed is not None :
922
- torch .manual_seed (args .seed )
923
- np .random .seed (args .seed )
924
- random .seed (args .seed )
925
-
926
928
return algorithm
927
929
928
930
0 commit comments