Skip to content

Commit

Permalink
Add info_spec kwarg to random_py_policy to allow generating info fields.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 320201608
Change-Id: I7e69f3ab9e558ae643fff7f34460356e736fc0bf
  • Loading branch information
Oscar Ramirez authored and copybara-github committed Jul 8, 2020
1 parent 18f12c7 commit 56a94d7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tf_agents/policies/random_py_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RandomPyPolicy(py_policy.PyPolicy):
def __init__(self,
time_step_spec: ts.TimeStep,
action_spec: types.NestedArraySpec,
info_spec: types.NestedArraySpec = (),
seed: Optional[types.Seed] = None,
outer_dims: Optional[Sequence[int]] = None,
observation_and_action_constraint_splitter: Optional[
Expand All @@ -50,6 +51,8 @@ def __init__(self,
given time_step when action is called.
action_spec: A nest of BoundedArraySpec representing the actions to sample
from.
info_spec: Nest of `ArraySpec` representing the data in the policy
info field. Defaults to `()`, i.e. an empty info nest.
seed: Optional seed used to instantiate a random number generator.
outer_dims: An optional list/tuple specifying outer dimensions to add to
the spec shape before sampling. If unspecified the outer_dims are
Expand Down Expand Up @@ -99,6 +102,7 @@ def observation_and_action_constraint_splitter(observation):
super(RandomPyPolicy, self).__init__(
time_step_spec=time_step_spec,
action_spec=action_spec,
info_spec=info_spec,
observation_and_action_constraint_splitter=(
observation_and_action_constraint_splitter))

Expand Down Expand Up @@ -132,4 +136,7 @@ def _action(self, time_step, policy_state):
random_action = array_spec.sample_spec_nest(
self._action_spec, self._rng, outer_dims=outer_dims)

return policy_step.PolicyStep(random_action, policy_state)
info = array_spec.sample_spec_nest(
self._info_spec, self._rng, outer_dims=outer_dims)

return policy_step.PolicyStep(random_action, policy_state, info)

0 comments on commit 56a94d7

Please sign in to comment.