Skip to content

Commit

Permalink
Add a whole corpus sampler
Browse files Browse the repository at this point in the history
This patch adds a whole corpus sampler. It is intended for trace data
collection, where generating a reward requires compiling an entire corpus
subset that has been extracted beforehand using separate tooling.
  • Loading branch information
boomanaiden154 committed Sep 12, 2024
1 parent 39369d3 commit 03ebc68
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
21 changes: 21 additions & 0 deletions compiler_opt/rl/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,27 @@ def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]:
return list(results)


class WholeCorpusSampler(Sampler):
  """Sampler that hands back every module in the corpus on each request.

  The requested sample size must match the corpus size exactly; any other
  value is rejected with a ValueError.
  """

  # super().__init__ is an abstract method, so it has to be overridden here
  # for the class to be usable without raising an error, even though the
  # override does nothing except delegate to the parent implementation.
  def __init__(self, module_specs: Tuple[ModuleSpec]):  # pylint: disable=useless-super-delegation
    super().__init__(module_specs)

  def reset(self):
    """Intentionally a no-op."""

  def __call__(self, k: int) -> List[ModuleSpec]:
    """Returns the entire corpus as a list of module specs.

    Args:
      k: number of modules requested; must equal the corpus size.

    Returns:
      A new list holding every module spec in the corpus.

    Raises:
      ValueError: if k differs from the number of modules in the corpus.
    """
    corpus_size = len(self._module_specs)
    if k == corpus_size:
      return list(self._module_specs)
    raise ValueError(
        f'The number of modules requested {k} is not equal to '
        f'the number of modules in the corpus, {corpus_size}')


class Corpus:
"""Represents a corpus.
Expand Down
28 changes: 28 additions & 0 deletions compiler_opt/rl/corpus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,34 @@ def test_sample_without_replacement(self):
self.assertEqual(samples[2].name, 'small')
self.assertEqual(samples[3].name, 'smol')

def test_whole_corpus_sampler(self):
  """Sampling the full count with sort=True returns all four modules.

  The expected order checked below is large, middle, small, xsmall.
  """
  tempdir = self.create_tempdir()
  names_and_sizes = (
      ('xsmall', 1),
      ('small', 5),
      ('middle', 20),
      ('large', 100),
  )
  module_specs = [
      corpus.ModuleSpec(name=name, size=size) for name, size in names_and_sizes
  ]
  whole_corpus = corpus.create_corpus_for_testing(
      location=tempdir,
      elements=module_specs,
      sampler_type=corpus.WholeCorpusSampler)

  sampled = whole_corpus.sample(4, sort=True)

  self.assertLen(sampled, 4)
  expected_names = ('large', 'middle', 'small', 'xsmall')
  for position, expected_name in enumerate(expected_names):
    self.assertEqual(sampled[position].name, expected_name)

def test_whole_corpus_sampler_invalid_count(self):
  """Requesting fewer modules than the corpus holds raises ValueError."""
  tempdir = self.create_tempdir()
  module_specs = [
      corpus.ModuleSpec(name='small', size=1),
      corpus.ModuleSpec(name='middle', size=2),
  ]
  whole_corpus = corpus.create_corpus_for_testing(
      location=tempdir,
      elements=module_specs,
      sampler_type=corpus.WholeCorpusSampler)

  # The corpus has two modules, so asking for one must be rejected.
  self.assertRaises(ValueError, whole_corpus.sample, 1)

def test_filter(self):
cps = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
Expand Down

0 comments on commit 03ebc68

Please sign in to comment.