Merge pull request #7 from JulienT01/agents_restructuration
fix previous agents, and add envs
Showing 6 changed files with 228 additions and 2 deletions.
@@ -0,0 +1 @@
from .finite import FiniteMDP
@@ -0,0 +1 @@
from .finite_mdp import FiniteMDP
@@ -0,0 +1,193 @@
import numpy as np

import rlberry.spaces as spaces
from rlberry.envs.interface import Model

import rlberry

logger = rlberry.logger


class FiniteMDP(Model):
    """
    Base class for a finite MDP.

    Terminal states are set to be absorbing, and are determined by the
    is_terminal() method, which can be overridden (and returns False by
    default).

    Parameters
    ----------
    R : numpy.ndarray
    P : numpy.ndarray
    initial_state_distribution : numpy.ndarray or int
        Array of size (S,) containing the initial state distribution,
        or an integer representing the initial/default state.

    Attributes
    ----------
    R : numpy.ndarray
        Array of shape (S, A) containing the mean rewards, where
        S = number of states and A = number of actions.
    P : numpy.ndarray
        Array of shape (S, A, S) containing the transition probabilities,
        where P[s, a, s'] = Prob(S_{t+1} = s' | S_t = s, A_t = a).
    """

    def __init__(self, R, P, initial_state_distribution=0):
        Model.__init__(self)
        self.initial_state_distribution = initial_state_distribution
        S, A = R.shape

        self.S = S
        self.A = A

        self.R = R
        self.P = P

        self.observation_space = spaces.Discrete(S)
        self.action_space = spaces.Discrete(A)
        self.reward_range = (self.R.min(), self.R.max())

        self.state = None

        self._states = np.arange(S)
        self._actions = np.arange(A)

        self.reset()
        self._process_terminal_states()
        self._check()

    def reset(self, seed=None, options=None):
        """
        Reset the environment to a default state.
        """
        if isinstance(self.initial_state_distribution, np.ndarray):
            self.state = self.rng.choice(
                self._states, p=self.initial_state_distribution
            )
        else:
            self.state = self.initial_state_distribution
        return self.state, {}

    def _process_terminal_states(self):
        """
        Adapt the transition array P so that terminal states are absorbing.
        """
        for ss in range(self.S):
            if self.is_terminal(ss):
                self.P[ss, :, :] = 0.0
                self.P[ss, :, ss] = 1.0

    def _check_init_distribution(self):
        if isinstance(self.initial_state_distribution, np.ndarray):
            assert abs(self.initial_state_distribution.sum() - 1.0) < 1e-15
        else:
            assert self.initial_state_distribution >= 0
            assert self.initial_state_distribution < self.S

    def _check(self):
        """
        Check the consistency of the MDP.
        """
        # Check initial_state_distribution
        self._check_init_distribution()

        # Check that P[s, a, :] is a probability distribution
        for s in self._states:
            for a in self._actions:
                assert abs(self.P[s, a, :].sum() - 1.0) < 1e-15

        # Check that dimensions match
        S1, A1 = self.R.shape
        S2, A2, S3 = self.P.shape
        assert S1 == S2 == S3
        assert A1 == A2

    def set_initial_state_distribution(self, distribution):
        """
        Parameters
        ----------
        distribution : numpy.ndarray or int
            Array of size (S,) containing the initial state distribution,
            or an integer representing the initial/default state.
        """
        self.initial_state_distribution = distribution
        self._check_init_distribution()

    def sample(self, state, action):
        """
        Sample a transition s' from P(s' | state, action).
        """
        prob = self.P[state, action, :]
        next_state = self.rng.choice(self._states, p=prob)
        reward = self.reward_fn(state, action, next_state)
        terminated = self.is_terminal(state)
        truncated = False
        info = {}
        return next_state, reward, terminated, truncated, info

    def step(self, action):
        assert action in self._actions, "Invalid action!"
        next_state, reward, terminated, truncated, info = self.sample(
            self.state, action
        )
        self.state = next_state
        return next_state, reward, terminated, truncated, info

    def is_terminal(self, state):
        """
        Returns True if a state is terminal.
        """
        return False

    def reward_fn(self, state, action, next_state):
        """
        Reward function. Returns the mean reward at (state, action) by default.

        Parameters
        ----------
        state : int
            current state
        action : int
            current action
        next_state : int
            next state

        Returns
        -------
        reward : float
        """
        return self.R[state, action]

    def log(self):
        """
        Print the structure of the MDP.
        """
        indent = " "
        for s in self._states:
            logger.info(f"State {s} {indent}")
            for a in self._actions:
                logger.info(f"{indent} Action {a}")
                for ss in self._states:
                    if self.P[s, a, ss] > 0.0:
                        logger.info(
                            f"{2 * indent} transition to {ss} "
                            f"with prob {self.P[s, a, ss]: .2f}"
                        )
            logger.info("~~~~~~~~~~~~~~~~~~~~")


# if __name__ == '__main__':
#     S = 3
#     A = 2

#     R = np.random.uniform(0, 1, (S, A))
#     P = np.random.uniform(0, 1, (S, A, S))
#     initial_state_distr = 1  # np.ones(S)/S
#     for ss in range(S):
#         for aa in range(A):
#             P[ss, aa, :] /= P[ss, aa, :].sum()

#     env = FiniteMDP(R, P, initial_state_distribution=initial_state_distr)
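To make the terminal-state behavior described in the class docstring concrete, here is a minimal usage sketch. It is not part of this commit: the two-state chain and the name ChainWithGoal are made up for illustration, and it assumes the FiniteMDP class above.

import numpy as np


class ChainWithGoal(FiniteMDP):
    """Hypothetical two-state chain where state 1 is a terminal goal."""

    def __init__(self):
        R = np.array([[0.0, 1.0], [0.0, 0.0]])  # action 1 in state 0 pays reward 1
        P = np.zeros((2, 2, 2))
        P[0, 0, 0] = 1.0  # action 0 stays in state 0
        P[0, 1, 1] = 1.0  # action 1 moves to the goal state 1
        P[1, :, 1] = 1.0  # goal is absorbing (also enforced by _process_terminal_states)
        FiniteMDP.__init__(self, R, P, initial_state_distribution=0)

    def is_terminal(self, state):
        return state == 1


env = ChainWithGoal()
state, info = env.reset()                                  # state == 0
state, reward, terminated, truncated, info = env.step(1)   # reward == 1.0, state == 1
state, reward, terminated, truncated, info = env.step(0)   # terminated is True

Because _process_terminal_states() runs in the constructor, any transitions out of state 1 would be overwritten to make it absorbing, and step() reports terminated=True once the agent acts from that state.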
@@ -0,0 +1,31 @@
from typing import Tuple
from copy import deepcopy
from rlberry.seeding import safe_reseed

import rlberry

logger = rlberry.logger


def process_env(env, seeder, copy_env=True):
    if isinstance(env, Tuple):
        constructor = env[0]
        if constructor is None:
            return None
        kwargs = env[1] or {}
        processed_env = constructor(**kwargs)
    else:
        if env is None:
            return None
        if copy_env:
            try:
                processed_env = deepcopy(env)
            except Exception as ex:
                raise RuntimeError("[Agent] Not possible to deepcopy env: " + str(ex))
        else:
            processed_env = env
    reseeded = safe_reseed(processed_env, seeder)
    if not reseeded:
        logger.warning("[Agent] Not possible to reseed environment.")
    return processed_env
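Similarly, a short sketch (again not part of the diff) of the two calling conventions process_env accepts; it assumes rlberry.seeding.Seeder and reuses the FiniteMDP class added above, with placeholder R and P arrays:

import numpy as np
from rlberry.seeding import Seeder

seeder = Seeder(42)

# Placeholder MDP: 3 states, 2 actions, every transition leads to state 0.
R = np.ones((3, 2))
P = np.zeros((3, 2, 3))
P[:, :, 0] = 1.0

# Case 1: pass an environment instance; it is deep-copied (copy_env=True), then reseeded.
env_copy = process_env(FiniteMDP(R, P), seeder, copy_env=True)

# Case 2: pass a (constructor, kwargs) tuple; the environment is built, then reseeded.
env_new = process_env((FiniteMDP, {"R": R, "P": P}), seeder)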