Skip to content

Commit dde2995

Browse files
committed
initial scaffold
Signed-off-by: Jack Luar <[email protected]>
1 parent 9cf53e9 commit dde2995

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

tools/AutoTuner/src/autotuner/distributed.py

+34-5
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
from uuid import uuid4 as uuid
4141

4242
import numpy as np
43+
import torch
4344

4445
import ray
4546
from ray import tune
@@ -805,7 +806,7 @@ def parse_arguments():
805806
help="Perturbation interval for PopulationBasedTraining.",
806807
)
807808
tune_parser.add_argument(
808-
"--seed", type=int, metavar="<int>", default=42, help="Random seed."
809+
"--seed", type=int, metavar="<int>", default=42, help="Random seed. (0 means no seed.)"
809810
)
810811

811812
# Workload
@@ -870,10 +871,23 @@ def set_algorithm(experiment_name, config):
870871
"""
871872
Configure search algorithm.
872873
"""
874+
# Pre-set seed if user sets seed to 0
875+
if args.seed == 0:
876+
print("Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)")
877+
if input().lower() != "y":
878+
sys.exit(0)
879+
args.seed = None
880+
873881
if args.algorithm == "hyperopt":
874-
algorithm = HyperOptSearch(points_to_evaluate=best_params)
882+
algorithm = HyperOptSearch(
883+
points_to_evaluate=best_params,
884+
random_state_seed=args.seed,
885+
)
875886
elif args.algorithm == "ax":
876-
ax_client = AxClient(enforce_sequential_optimization=False)
887+
ax_client = AxClient(
888+
enforce_sequential_optimization=False,
889+
random_seed=args.seed,
890+
)
877891
AxClientMetric = namedtuple("AxClientMetric", "minimize")
878892
ax_client.create_experiment(
879893
name=experiment_name,
@@ -882,18 +896,33 @@ def set_algorithm(experiment_name, config):
882896
)
883897
algorithm = AxSearch(ax_client=ax_client, points_to_evaluate=best_params)
884898
elif args.algorithm == "optuna":
885-
algorithm = OptunaSearch(points_to_evaluate=best_params, seed=args.seed)
899+
algorithm = OptunaSearch(
900+
points_to_evaluate=best_params,
901+
seed=args.seed
902+
)
886903
elif args.algorithm == "pbt":
904+
print("Warning: PBT does not support seed values. args.seed will be ignored.")
887905
algorithm = PopulationBasedTraining(
888906
time_attr="training_iteration",
889907
perturbation_interval=args.perturbation,
890908
hyperparam_mutations=config,
891909
synch=True,
892910
)
893911
elif args.algorithm == "random":
894-
algorithm = BasicVariantGenerator(max_concurrent=args.jobs)
912+
algorithm = BasicVariantGenerator(
913+
max_concurrent=args.jobs,
914+
random_state=args.seed,)
915+
916+
# A wrapper algorithm for limiting the number of concurrent trials.
895917
if args.algorithm not in ["random", "pbt"]:
896918
algorithm = ConcurrencyLimiter(algorithm, max_concurrent=args.jobs)
919+
920+
# Self seed
921+
if args.seed is not None:
922+
torch.manual_seed(args.seed)
923+
np.random.seed(args.seed)
924+
random.seed(args.seed)
925+
897926
return algorithm
898927

899928

0 commit comments

Comments (0)