Skip to content

Commit

Permalink
update state and state test
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Nov 29, 2023
1 parent 2052d16 commit 8e88660
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
38 changes: 19 additions & 19 deletions all/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,39 +158,39 @@ def update(self, key, value):
return self.__class__(x, device=self.device)

@classmethod
def from_gym(cls, state, device='cpu', dtype=np.float32):
def from_gym(cls, gym_output, device='cpu', dtype=np.float32):
"""
Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
Args:
state (tuple): The return value of an OpenAI gym reset()/step(action) call
gym_output (tuple): The output value of an OpenAI gym reset()/step(action) call
device (string): The device on which to store resulting tensors.
dtype: The type of the observation.
Returns:
A State object.
"""
if not isinstance(state, tuple):
return State({
'observation': torch.from_numpy(
if not isinstance(gym_output, tuple) and (len(gym_output) == 2 or len(gym_output) == 5):
raise TypeError(f"gym_output should be a tuple, either (observation, info) or (observation, reward, terminated, truncated, info). Recieved {gym_output}.")

# extract info from timestep
if len(gym_output) == 5:
observation, reward, terminated, truncated, info = gym_output
if len(gym_output) == 2:
observation, info = gym_output
reward = 0.
terminated = False
truncated = False
x = {
'observation': torch.from_numpy(
np.array(
state,
observation,
dtype=dtype
),
).to(device)
}, device=device)

observation, reward, done, info = state
observation = torch.from_numpy(
np.array(
observation,
dtype=dtype
),
).to(device)
x = {
'observation': observation,
).to(device),
'reward': float(reward),
'done': done,
'done': terminated or truncated,
'mask': 1. - terminated
}
info = info if info else {}
for key in info:
Expand Down
19 changes: 15 additions & 4 deletions all/core/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,34 @@ def test_auto_mask_false(self):

def test_from_gym_reset(self):
observation = np.array([1, 2, 3])
state = State.from_gym(observation)
state = State.from_gym((observation, {'coolInfo': 3}))
tt.assert_equal(state.observation, torch.from_numpy(observation))
self.assertEqual(state.mask, 1.)
self.assertEqual(state.done, False)
self.assertEqual(state.reward, 0.)
self.assertEqual(state.shape, ())
self.assertEqual(state['coolInfo'], 3.)

def test_from_gym_step(self):
observation = np.array([1, 2, 3])
state = State.from_gym((observation, 2., True, {'coolInfo': 3.}))
state = State.from_gym((observation, 2., True, False, {'coolInfo': 3.}))
tt.assert_equal(state.observation, torch.from_numpy(observation))
self.assertEqual(state.mask, 0.)
self.assertEqual(state.done, True)
self.assertEqual(state.reward, 2.)
self.assertEqual(state['coolInfo'], 3.)
self.assertEqual(state.shape, ())

def test_from_truncated_gym_step(self):
observation = np.array([1, 2, 3])
state = State.from_gym((observation, 2., False, True, {'coolInfo': 3.}))
tt.assert_equal(state.observation, torch.from_numpy(observation))
self.assertEqual(state.mask, 1.)
self.assertEqual(state.done, True)
self.assertEqual(state.reward, 2.)
self.assertEqual(state['coolInfo'], 3.)
self.assertEqual(state.shape, ())

def test_as_input(self):
observation = torch.randn(3, 4)
state = State(observation)
Expand All @@ -79,7 +90,7 @@ def test_as_output(self):

def test_apply_mask(self):
observation = torch.randn(3, 4)
state = State.from_gym((observation, 0., True, {}))
state = State.from_gym((observation, 0., True, False, {}))
tt.assert_equal(state.apply_mask(observation), torch.zeros(3, 4))

def test_apply(self):
Expand All @@ -92,7 +103,7 @@ def test_apply(self):

def test_apply_done(self):
observation = torch.randn(3, 4)
state = State.from_gym((observation, 0., True, {}))
state = State.from_gym((observation, 0., True, False, {}))
model = torch.nn.Conv1d(3, 5, 2)
output = state.apply(model, 'observation')
self.assertEqual(output.shape, (5, 3))
Expand Down

0 comments on commit 8e88660

Please sign in to comment.