From 523094708f18754f456310ae98584437d237b89c Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Mon, 16 Sep 2024 15:10:56 -0700 Subject: [PATCH] Serialize entire policy in blackbox_learner (#366) This patch adjusts blackbox_learner so that it returns an entire policy rather than just the bytes of the policy. When actually running evaluations, we need to write out the full policy, including the output spec, to disk so the compiler can pick it up. Before this patch, we were not passing along the output spec to the worker. --- compiler_opt/es/blackbox_evaluator.py | 9 +++++---- compiler_opt/es/blackbox_learner.py | 18 +++++++++--------- compiler_opt/es/blackbox_learner_test.py | 3 ++- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py index 2bd68c51..787ae61d 100644 --- a/compiler_opt/es/blackbox_evaluator.py +++ b/compiler_opt/es/blackbox_evaluator.py @@ -25,6 +25,7 @@ from compiler_opt.rl import corpus from compiler_opt.es import blackbox_optimizers from compiler_opt.distributed import buffered_scheduler +from compiler_opt.rl import policy_saver class BlackboxEvaluator(metaclass=abc.ABCMeta): @@ -36,8 +37,8 @@ def __init__(self, train_corpus: corpus.Corpus): @abc.abstractmethod def get_results( - self, pool: FixedWorkerPool, - perturbations: List[bytes]) -> List[concurrent.futures.Future]: + self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy] + ) -> List[concurrent.futures.Future]: raise NotImplementedError() @abc.abstractmethod @@ -66,8 +67,8 @@ def __init__(self, train_corpus: corpus.Corpus, super().__init__(train_corpus) def get_results( - self, pool: FixedWorkerPool, - perturbations: List[bytes]) -> List[concurrent.futures.Future]: + self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy] + ) -> List[concurrent.futures.Future]: if not self._samples: for _ in range(self._total_num_perturbations): sample = 
self._train_corpus.sample(self._num_ir_repeats_within_worker) diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py index 984ef5a3..c4e82b62 100644 --- a/compiler_opt/es/blackbox_learner.py +++ b/compiler_opt/es/blackbox_learner.py @@ -223,8 +223,10 @@ def _save_model(self) -> None: def get_model_weights(self) -> npt.NDArray[np.float32]: return self._model_weights - def _get_policy_as_bytes(self, - perturbation: npt.NDArray[np.float32]) -> bytes: + # TODO: The current conversion is inefficient (performance-wise). We should + # consider doing this on the worker side. + def _get_policy_from_perturbation( + self, perturbation: npt.NDArray[np.float32]) -> policy_saver.Policy: sm = tf.saved_model.load(self._tf_policy_path) # devectorize the perturbation policy_utils.set_vectorized_parameters_for_policy(sm, perturbation) @@ -242,7 +244,7 @@ def _get_policy_as_bytes(self, # create and return policy policy_obj = policy_saver.Policy.from_filesystem(tfl_dir) - return policy_obj.policy + return policy_obj def run_step(self, pool: FixedWorkerPool) -> None: """Run a single step of blackbox learning. @@ -258,14 +260,12 @@ def run_step(self, pool: FixedWorkerPool) -> None: p for p in initial_perturbations for p in (p, -p) ] - # convert to bytes for compile job - # TODO: current conversion is inefficient. 
- # consider doing this on the worker side - perturbations_as_bytes = [] + perturbations_as_policies = [] for perturbation in initial_perturbations: - perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation)) + perturbations_as_policies.append( + self._get_policy_from_perturbation(perturbation)) - results = self._evaluator.get_results(pool, perturbations_as_bytes) + results = self._evaluator.get_results(pool, perturbations_as_policies) rewards = self._evaluator.get_rewards(results) num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards) diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py index cb5667c5..5f74a13a 100644 --- a/compiler_opt/es/blackbox_learner_test.py +++ b/compiler_opt/es/blackbox_learner_test.py @@ -45,7 +45,8 @@ def __init__(self, arg, *, kwarg): self._kwarg = kwarg self.function_value = 0.0 - def compile(self, policy: bytes, samples: List[corpus.ModuleSpec]) -> float: + def compile(self, policy: policy_saver.Policy, + samples: List[corpus.ModuleSpec]) -> float: if policy and samples: self.function_value += 1.0 return self.function_value