Merge pull request #7 from JulienT01/agents_restructuration
fix previous agents, and add envs
JulienT01 authored Nov 6, 2023
2 parents 8917f1d + 10069ea commit c5f1dda
Showing 6 changed files with 228 additions and 2 deletions.
2 changes: 1 addition & 1 deletion rlberry_scool/agents/dynprog/value_iteration.py
@@ -1,6 +1,6 @@
from rlberry.agents.agent import AgentWithSimplePolicy
from rlberry_scool.agents.dynprog.utils import backward_induction, value_iteration
-from rlberry.envs.finite.finite_mdp import FiniteMDP
+from rlberry_scool.envs.finite.finite_mdp import FiniteMDP


class ValueIterationAgent(AgentWithSimplePolicy):
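For context, a minimal usage sketch of the agent with the relocated FiniteMDP import (the environment class added by this commit). The constructor argument gamma and the fit/policy calls are assumed from the usual AgentWithSimplePolicy interface rather than taken from this diff, so check the class for the exact signature.

import numpy as np

from rlberry_scool.agents.dynprog.value_iteration import ValueIterationAgent
from rlberry_scool.envs.finite.finite_mdp import FiniteMDP

# Hypothetical 2-state, 2-action MDP with exact transition probabilities.
R = np.array([[0.0, 1.0], [0.5, 0.25]])
P = np.array(
    [
        [[0.75, 0.25], [0.5, 0.5]],  # transitions from state 0
        [[0.25, 0.75], [1.0, 0.0]],  # transitions from state 1
    ]
)

env = FiniteMDP(R, P)
# gamma, fit(budget) and policy(observation) are assumed here, not shown in this diff.
agent = ValueIterationAgent(env, gamma=0.95)
agent.fit(budget=1)
observation, info = env.reset()
action = agent.policy(observation)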
2 changes: 1 addition & 1 deletion rlberry_scool/agents/ucbvi/ucbvi.py
@@ -6,7 +6,7 @@
update_value_and_get_action,
update_value_and_get_action_sd,
)
-from rlberry.exploration_tools.discrete_counter import DiscreteCounter
+from rlberry_research.exploration_tools.discrete_counter import DiscreteCounter
from rlberry_scool.agents.dynprog.utils import (
backward_induction_sd,
backward_induction_reward_sd,
1 change: 1 addition & 0 deletions rlberry_scool/envs/__init__.py
@@ -0,0 +1 @@
from .finite import FiniteMDP
1 change: 1 addition & 0 deletions rlberry_scool/envs/finite/__init__.py
@@ -0,0 +1 @@
from .finite_mdp import FiniteMDP
193 changes: 193 additions & 0 deletions rlberry_scool/envs/finite/finite_mdp.py
@@ -0,0 +1,193 @@
import numpy as np


import rlberry.spaces as spaces
from rlberry.envs.interface import Model

import rlberry

logger = rlberry.logger


class FiniteMDP(Model):
"""
Base class for a finite MDP.
Terminal states are set to be absorbing, and
are determined by the is_terminal() method,
which can be overriden (and returns false by default).
Parameters
----------
R : numpy.ndarray
P : numpy.ndarray
initial_state_distribution : numpy.ndarray or int
array of size (S,) containing the initial state distribution
or an integer representing the initial/default state
Attributes
----------
R : numpy.ndarray
array of shape (S, A) containing the mean rewards, where
S = number of states; A = number of actions.
P : numpy.ndarray
array of shape (S, A, S) containing the transition probabilities,
where P[s, a, s'] = Prob(S_{t+1}=s'| S_t = s, A_t = a).
"""

def __init__(self, R, P, initial_state_distribution=0):
Model.__init__(self)
self.initial_state_distribution = initial_state_distribution
S, A = R.shape

self.S = S
self.A = A

self.R = R
self.P = P

self.observation_space = spaces.Discrete(S)
self.action_space = spaces.Discrete(A)
self.reward_range = (self.R.min(), self.R.max())

self.state = None

self._states = np.arange(S)
self._actions = np.arange(A)

self.reset()
self._process_terminal_states()
self._check()

def reset(self, seed=None, options=None):
"""
Reset the environment to a default state.
"""
if isinstance(self.initial_state_distribution, np.ndarray):
self.state = self.rng.choice(
self._states, p=self.initial_state_distribution
)
else:
self.state = self.initial_state_distribution
return self.state, {}

def _process_terminal_states(self):
"""
Adapt transition array P so that terminal states
are absorbing.
"""
for ss in range(self.S):
if self.is_terminal(ss):
self.P[ss, :, :] = 0.0
self.P[ss, :, ss] = 1.0

def _check_init_distribution(self):
if isinstance(self.initial_state_distribution, np.ndarray):
assert abs(self.initial_state_distribution.sum() - 1.0) < 1e-15
else:
assert self.initial_state_distribution >= 0
assert self.initial_state_distribution < self.S

def _check(self):
"""
Check consistency of the MDP
"""
# Check initial_state_distribution
self._check_init_distribution()

# Check that P[s,a, :] is a probability distribution
for s in self._states:
for a in self._actions:
assert abs(self.P[s, a, :].sum() - 1.0) < 1e-15

# Check that dimensions match
S1, A1 = self.R.shape
S2, A2, S3 = self.P.shape
assert S1 == S2 == S3
assert A1 == A2

def set_initial_state_distribution(self, distribution):
"""
Parameters
----------
distribution : numpy.ndarray or int
array of size (S,) containing the initial state distribution
or an integer representing the initial/default state
"""
self.initial_state_distribution = distribution
self._check_init_distribution()

def sample(self, state, action):
"""
Sample a transition s' from P(s'|state, action).
"""
prob = self.P[state, action, :]
next_state = self.rng.choice(self._states, p=prob)
reward = self.reward_fn(state, action, next_state)
terminated = self.is_terminal(state)
truncated = False
info = {}
return next_state, reward, terminated, truncated, info

def step(self, action):
assert action in self._actions, "Invalid action!"
next_state, reward, terminated, truncated, info = self.sample(
self.state, action
)
self.state = next_state
return next_state, reward, terminated, truncated, info

def is_terminal(self, state):
"""
        Returns True if a state is terminal.
"""
return False

def reward_fn(self, state, action, next_state):
"""
Reward function. Returns mean reward at (state, action) by default.
Parameters
----------
state : int
current state
action : int
current action
next_state :
next state
Returns:
reward : float
"""
return self.R[state, action]

def log(self):
"""
Print the structure of the MDP.
"""
indent = " "
for s in self._states:
logger.info(f"State {s} {indent}")
for a in self._actions:
logger.info(f"{indent} Action {a}")
for ss in self._states:
if self.P[s, a, ss] > 0.0:
logger.info(
f"{2 * indent} transition to {ss} "
f"with prob {self.P[s, a, ss]: .2f}"
)
logger.info("~~~~~~~~~~~~~~~~~~~~")


# if __name__ == '__main__':
# S = 3
# A = 2

# R = np.random.uniform(0, 1, (S, A))
# P = np.random.uniform(0, 1, (S, A, S))
# initial_state_distr = 1 # np.ones(S)/S
# for ss in range(S):
# for aa in range(A):
# P[ss, aa, :] /= P[ss, aa, :].sum()

# env = FiniteMDP(R, P, initial_state_distribution=initial_state_distr)
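As a quick sanity check, here is a runnable variant of the commented-out example above, kept as a minimal sketch: the subclass and its terminal state are hypothetical and only illustrate how _process_terminal_states makes terminal states absorbing and how step returns the Gym-style 5-tuple.

import numpy as np

from rlberry_scool.envs.finite.finite_mdp import FiniteMDP

S, A = 3, 2
R = np.random.uniform(0, 1, (S, A))
P = np.random.uniform(0, 1, (S, A, S))
for ss in range(S):
    for aa in range(A):
        P[ss, aa, :] /= P[ss, aa, :].sum()  # make each P[s, a, :] sum to 1


class TerminalFiniteMDP(FiniteMDP):
    # Hypothetical subclass: the last state is terminal, hence made absorbing.
    def is_terminal(self, state):
        return state == self.S - 1


env = TerminalFiniteMDP(R, P, initial_state_distribution=np.ones(S) / S)
observation, info = env.reset()
observation, reward, terminated, truncated, info = env.step(0)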
31 changes: 31 additions & 0 deletions rlberry_scool/envs/utils.py
@@ -0,0 +1,31 @@
from typing import Tuple
from copy import deepcopy
from rlberry.seeding import safe_reseed


import rlberry

logger = rlberry.logger


def process_env(env, seeder, copy_env=True):
if isinstance(env, Tuple):
constructor = env[0]
if constructor is None:
return None
kwargs = env[1] or {}
processed_env = constructor(**kwargs)
else:
if env is None:
return None
if copy_env:
try:
processed_env = deepcopy(env)
except Exception as ex:
raise RuntimeError("[Agent] Not possible to deepcopy env: " + str(ex))
else:
processed_env = env
reseeded = safe_reseed(processed_env, seeder)
if not reseeded:
logger.warning("[Agent] Not possible to reseed environment.")
return processed_env
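A short usage sketch for process_env, assuming a Seeder from rlberry.seeding is the expected seeder argument: the tuple form (constructor, kwargs) defers construction to process_env, while an already-instantiated env is deep-copied (when copy_env=True) and reseeded.

import numpy as np

from rlberry.seeding import Seeder
from rlberry_scool.envs import FiniteMDP
from rlberry_scool.envs.utils import process_env

# Hypothetical 2-state, 2-action MDP with uniform transitions.
R = np.ones((2, 2))
P = np.full((2, 2, 2), 0.5)
seeder = Seeder(42)

# Tuple form: the constructor is called with the kwargs, then the env is reseeded.
env1 = process_env((FiniteMDP, {"R": R, "P": P}), seeder)

# Instance form: the env is deep-copied and reseeded.
env2 = process_env(FiniteMDP(R, P), seeder)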
