-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_env.py
66 lines (54 loc) · 2.54 KB
/
simple_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gym
import numpy as np
from numpy.typing import NDArray
from typing import Any, Optional
# A very simple environment that gives -1 reward for every step, but 20 reward
# for a sequence of 5 identical steps.
class SimpleEnvV0(gym.Env[NDArray[np.int32], int]):
    """Toy env: -1 reward per step, +20 for five consecutive identical actions.

    The observation is a rolling window of the last ``repeats_needed`` actions,
    encoded as ``action + 1`` so that an untouched slot (0) is distinguishable
    from a real action.  The episode terminates when the whole window equals 1,
    i.e. after five consecutive 0-actions.
    """

    # Rolling window of the last `repeats_needed` encoded actions.
    state: NDArray[np.int32]

    def __init__(self, render_mode: None = None) -> None:
        # Rendering is not supported; only render_mode=None is accepted.
        assert render_mode is None
        self.repeats_needed = 5
        self.reward_per_turn = -1
        self.reward_on_success = 20
        # Observations hold values in {0, 1, 2} (empty slot, action 0, action 1),
        # so each slot needs 3 categories, not 2.
        self.observation_space = gym.spaces.MultiDiscrete([3] * self.repeats_needed)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed: Optional[int] = None, options: Optional[dict[Any, Any]] = None) -> tuple[NDArray[np.int32], dict[Any, Any]]:
        """Clear the action window and return the initial observation."""
        # Forward the seed so gym seeds self.np_random per its reset contract.
        super().reset(seed=seed)
        self.state = np.zeros(self.repeats_needed, dtype=np.int32)
        self.done = False
        return self.state, {}

    def step(self, action: int) -> tuple[NDArray[np.int32], float, bool, bool, dict[Any, Any]]:
        """Shift the window left and append the encoded action.

        Returns (observation, reward, terminated, truncated, info).
        Raises ValueError for actions outside {0, 1}.
        """
        if action not in (0, 1):
            raise ValueError("Invalid action")
        # Encode as action+1 and keep int32 so `state` matches its annotation
        # (a default np.array would promote the result to int64).
        self.state = np.concatenate(
            (self.state[1:], np.array([action + 1], dtype=np.int32))
        )
        # bool(...) converts np.bool_ to the plain bool promised by the signature.
        self.done = bool((self.state == 1).all())
        reward = self.reward_on_success if self.done else self.reward_per_turn
        return self.state, reward, self.done, False, {}
# A "copy me" env - +1 for successful copy, -1 for fail, -10 for if score
# drops below -50.
class SimpleEnvV1(gym.Env[int, int]):
    """'Copy me' env: +1 for echoing the observation, -1 otherwise.

    A target value is drawn at reset and stays fixed for the episode; the agent
    is rewarded for choosing the action equal to it.  The episode terminates
    once the cumulative score drops below ``fail_threshold``.
    """

    def __init__(self, render_mode: None = None, size: int = 5, fail_threshold: int = -50) -> None:
        # Rendering is not supported; only render_mode=None is accepted.
        assert render_mode is None
        self.size = size
        self.fail_threshold = fail_threshold
        self.observation_space = gym.spaces.Discrete(self.size)
        self.action_space = gym.spaces.Discrete(self.size)

    def reset(self, seed: Optional[int] = None, options: Optional[dict[Any, Any]] = None) -> tuple[int, dict[Any, Any]]:
        """Draw a new target value and return it as the first observation."""
        # Forward the seed so gym seeds self.np_random per its reset contract.
        super().reset(seed=seed)
        # Draw from the seeded per-env generator rather than the global numpy
        # RNG, so resets are reproducible given a seed.
        self.state = int(self.np_random.integers(self.size))
        self.score = 0
        self.done = False
        return self.state, {}

    def step(self, action: int) -> tuple[int, float, bool, bool, dict[Any, Any]]:
        """Reward matching the target; terminate when the score falls too low.

        Returns (observation, reward, terminated, truncated, info); info
        carries the running 'score'.
        """
        reward = 1 if action == self.state else -1
        self.score += reward
        self.done = self.score < self.fail_threshold
        info = {'score': self.score}
        return self.state, reward, self.done, False, info
# Register both toy environments so they can be instantiated via gym.make().
for _env_id, _entry_point in (
    ('SimpleEnv-v0', SimpleEnvV0),
    ('SimpleEnv-v1', SimpleEnvV1),
):
    gym.envs.registration.register(id=_env_id, entry_point=_entry_point)