Skip to content

Commit

Permalink
PUBLIC: Add an additional field to the quicksort probing, update it…
Browse files Browse the repository at this point in the history
…s specifications, and add a `track_max_steps` flag to the sampler. This flag will enable or disable length tracking for padding.

PiperOrigin-RevId: 638783182
  • Loading branch information
CLRSDev authored and copybara-github committed May 30, 2024
1 parent 5635129 commit 006e021
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 55 deletions.
12 changes: 8 additions & 4 deletions clrs/_src/algorithms/searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[j], A.shape[0]),
'i_rank': (i + 1) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
})
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[r], A.shape[0]),
},
)

tmp = A[i + 1]
A[i + 1] = A[r]
Expand All @@ -199,8 +201,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[r], A.shape[0]),
'i_rank': (i + 1 - p) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
})
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[i + 1], A.shape[0]),
},
)

return i + 1

Expand Down
60 changes: 42 additions & 18 deletions clrs/_src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
num_samples: int,
*args,
seed: Optional[int] = None,
track_max_steps: bool = True,
**kwargs,
):
"""Initializes a `Sampler`.
Expand All @@ -80,11 +81,18 @@ def __init__(
algorithm: The algorithm to sample from
spec: The algorithm spec.
num_samples: Number of algorithm unrolls to sample. If positive, all the
samples will be generated in the constructor, and at each call
of the `next` method a batch will be randomly selected among them.
If -1, samples are generated on the fly with each call to `next`.
samples will be generated in the constructor, and at each call of the
`next` method a batch will be randomly selected among them. If -1,
samples are generated on the fly with each call to `next`.
*args: Algorithm args.
seed: RNG seed.
track_max_steps: if True and sampling on the fly (`num_samples`==-1), we
keep track of the maximum unroll length so far to pad batched samples to
that length. This is desirable when batches are used in compiled
functions that need recompilation every time the batch size changes.
Also, we get an initial value for max_steps by generating 1000 samples,
which will slow down initialization. If uniform shape of the batches is
not a concern, set `track_max_steps` to False.
**kwargs: Algorithm kwargs.
"""

Expand All @@ -95,19 +103,21 @@ def __init__(
self._algorithm = algorithm
self._args = args
self._kwargs = kwargs
self._track_max_steps = track_max_steps

if num_samples < 0:
logging.warning('Sampling dataset on-the-fly, unlimited samples.')
# Just get an initial estimate of max hint length
self.max_steps = -1
for _ in range(1000):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
_, _, hint = probing.split_stages(probes, spec)
for dp in hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > self.max_steps:
self.max_steps = dp.data.shape[0]
if track_max_steps:
# Get an initial estimate of max hint length
self.max_steps = -1
for _ in range(1000):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
_, _, hint = probing.split_stages(probes, spec)
for dp in hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > self.max_steps:
self.max_steps = dp.data.shape[0]
else:
logging.info('Creating a dataset with %i samples.', num_samples)
(self._inputs, self._outputs, self._hints,
Expand Down Expand Up @@ -148,10 +158,16 @@ def next(self, batch_size: Optional[int] = None) -> Feedback:
"""
if batch_size:
if self._num_samples < 0: # generate on the fly
min_length = self.max_steps if self._track_max_steps else 0
inputs, outputs, hints, lengths = self._make_batch(
batch_size, self._spec, self.max_steps,
self._algorithm, *self._args, **self._kwargs)
if hints[0].data.shape[0] > self.max_steps:
batch_size,
self._spec,
min_length,
self._algorithm,
*self._args,
**self._kwargs,
)
if self._track_max_steps and hints[0].data.shape[0] > self.max_steps:
logging.warning('Increasing hint lengh from %i to %i',
self.max_steps, hints[0].data.shape[0])
self.max_steps = hints[0].data.shape[0]
Expand Down Expand Up @@ -261,6 +277,7 @@ def build_sampler(
num_samples: int,
*args,
seed: Optional[int] = None,
track_max_steps: bool = True,
**kwargs,
) -> Tuple[Sampler, specs.Spec]:
"""Builds a sampler. See `Sampler` documentation."""
Expand All @@ -276,8 +293,15 @@ def build_sampler(
if set(clean_kwargs) != set(kwargs):
logging.warning('Ignoring kwargs %s when building sampler class %s',
set(kwargs).difference(clean_kwargs), sampler_class)
sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
*args, **clean_kwargs)
sampler = sampler_class(
algorithm,
spec,
num_samples,
seed=seed,
track_max_steps=track_max_steps,
*args,
**clean_kwargs,
)
return sampler, spec


Expand Down
Loading

0 comments on commit 006e021

Please sign in to comment.