Skip to content

Commit

Permalink
Refactor test ESWorker to separate class (#370)
Browse files Browse the repository at this point in the history
This doesn't technically remove a circular dependency, but it does
remove an awkward one.
  • Loading branch information
boomanaiden154 authored Sep 17, 2024
1 parent 9c81ac6 commit 47acf0c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 25 deletions.
4 changes: 2 additions & 2 deletions compiler_opt/es/blackbox_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from compiler_opt.distributed.local import local_worker_manager
from compiler_opt.rl import corpus
from compiler_opt.es import blackbox_learner_test
from compiler_opt.es import blackbox_test_utils
from compiler_opt.es import blackbox_evaluator


Expand All @@ -29,7 +29,7 @@ class BlackboxEvaluatorTests(absltest.TestCase):

def test_sampling_get_results(self):
with local_worker_manager.LocalWorkerPoolManager(
blackbox_learner_test.ESWorker, count=3, arg='', kwarg='') as pool:
blackbox_test_utils.ESWorker, count=3, arg='', kwarg='') as pool:
perturbations = [b'00', b'01', b'10']
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
# pylint: disable=protected-access
Expand Down
25 changes: 2 additions & 23 deletions compiler_opt/es/blackbox_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
from absl.testing import absltest
import gin
import tempfile
from typing import List
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import actor_policy

from compiler_opt.distributed import worker
from compiler_opt.distributed.local import local_worker_manager
from compiler_opt.es import blackbox_learner
from compiler_opt.es import policy_utils
Expand All @@ -36,26 +34,7 @@
from compiler_opt.rl import registry
from compiler_opt.rl.inlining import config as inlining_config
from compiler_opt.es import blackbox_evaluator


@gin.configurable
class ESWorker(worker.Worker):
"""Temporary placeholder worker.
Each time a worker is called, the function value
it will return increases."""

def __init__(self, arg, *, kwarg):
self._arg = arg
self._kwarg = kwarg
self.function_value = 0.0

def compile(self, policy: policy_saver.Policy,
samples: List[corpus.ModuleSpec]) -> float:
if policy and samples:
self.function_value += 1.0
return self.function_value
else:
return 0.0
from compiler_opt.es import blackbox_test_utils


class BlackboxLearnerTests(absltest.TestCase):
Expand Down Expand Up @@ -165,7 +144,7 @@ def test_prune_skipped_perturbations(self):

def test_run_step(self):
with local_worker_manager.LocalWorkerPoolManager(
ESWorker, count=3, arg='', kwarg='') as pool:
blackbox_test_utils.ESWorker, count=3, arg='', kwarg='') as pool:
self._learner.run_step(pool) # pylint: disable=protected-access
# expected length calculated from expected shapes of variables
self.assertEqual(len(self._learner.get_model_weights()), 17154)
Expand Down
43 changes: 43 additions & 0 deletions compiler_opt/es/blackbox_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test facilities for Blackbox classes."""

from typing import List

import gin

from compiler_opt.distributed import worker
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver


@gin.configurable
class ESWorker(worker.Worker):
"""Temporary placeholder worker.
Each time a worker is called, the function value
it will return increases."""

def __init__(self, arg, *, kwarg):
self._arg = arg
self._kwarg = kwarg
self.function_value = 0.0

def compile(self, policy: policy_saver.Policy,
samples: List[corpus.ModuleSpec]) -> float:
if policy and samples:
self.function_value += 1.0
return self.function_value
else:
return 0.0

0 comments on commit 47acf0c

Please sign in to comment.