Skip to content

Commit 7c60e29

Browse files
committed
apply code suggestion, fix black
Signed-off-by: luarss <[email protected]>
1 parent dde2995 commit 7c60e29

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tools/AutoTuner/src/autotuner/distributed.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,11 @@ def parse_arguments():
806806
help="Perturbation interval for PopulationBasedTraining.",
807807
)
808808
tune_parser.add_argument(
809-
"--seed", type=int, metavar="<int>", default=42, help="Random seed. (0 means no seed.)"
809+
"--seed",
810+
type=int,
811+
metavar="<int>",
812+
default=42,
813+
help="Random seed. (0 means no seed.)",
810814
)
811815

812816
# Workload
@@ -873,10 +877,16 @@ def set_algorithm(experiment_name, config):
873877
"""
874878
# Pre-set seed if user sets seed to 0
875879
if args.seed == 0:
876-
print("Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)")
880+
print(
881+
"Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
882+
)
877883
if input().lower() != "y":
878884
sys.exit(0)
879885
args.seed = None
886+
else:
887+
torch.manual_seed(args.seed)
888+
np.random.seed(args.seed)
889+
random.seed(args.seed)
880890

881891
if args.algorithm == "hyperopt":
882892
algorithm = HyperOptSearch(
@@ -896,10 +906,7 @@ def set_algorithm(experiment_name, config):
896906
)
897907
algorithm = AxSearch(ax_client=ax_client, points_to_evaluate=best_params)
898908
elif args.algorithm == "optuna":
899-
algorithm = OptunaSearch(
900-
points_to_evaluate=best_params,
901-
seed=args.seed
902-
)
909+
algorithm = OptunaSearch(points_to_evaluate=best_params, seed=args.seed)
903910
elif args.algorithm == "pbt":
904911
print("Warning: PBT does not support seed values. args.seed will be ignored.")
905912
algorithm = PopulationBasedTraining(
@@ -911,18 +918,13 @@ def set_algorithm(experiment_name, config):
911918
elif args.algorithm == "random":
912919
algorithm = BasicVariantGenerator(
913920
max_concurrent=args.jobs,
914-
random_state=args.seed,)
921+
random_state=args.seed,
922+
)
915923

916924
# A wrapper algorithm for limiting the number of concurrent trials.
917925
if args.algorithm not in ["random", "pbt"]:
918926
algorithm = ConcurrencyLimiter(algorithm, max_concurrent=args.jobs)
919927

920-
# Self seed
921-
if args.seed is not None:
922-
torch.manual_seed(args.seed)
923-
np.random.seed(args.seed)
924-
random.seed(args.seed)
925-
926928
return algorithm
927929

928930

0 commit comments

Comments (0)