Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

74 Policy evaluation and training cli (rllib) #85

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
68f7394
Add flatland training cli with ray.
chenkins Nov 1, 2024
f62da96
Fix error from ray passive_env_checker.
chenkins Nov 6, 2024
6f36350
Fix flattened gym observation builders.
chenkins Nov 6, 2024
63a65f3
Unit test for (gym) observation builder returned (d)types sizes/shapes.
chenkins Nov 6, 2024
7e682de
Extract ray examples.
chenkins Nov 11, 2024
6abc591
Update TODO.
chenkins Nov 16, 2024
3f903ca
Add ray address cli param.
chenkins Nov 19, 2024
3fe01dc
Cleanup training cli example.
chenkins Dec 6, 2024
e9d389b
Create test stub for env_creator.
chenkins Dec 6, 2024
ecf6deb
Mark training test as slow.
chenkins Dec 6, 2024
5c7a17e
Gym obs builder unit tests.
chenkins Dec 6, 2024
73d6bd2
Code cleanup gym obs builders.
chenkins Dec 6, 2024
37fe60f
Code cleanup gym obs builders.
chenkins Dec 6, 2024
113e719
Add option checkpointing.
chenkins Dec 7, 2024
b5e4681
Add rllib_demo.ipynb.
chenkins Dec 7, 2024
e2a758c
Remove obsolete ray shutdown.
chenkins Dec 9, 2024
952688d
Add rllib_demo.ipynb.
chenkins Dec 9, 2024
8b1f65a
Cleanup example cli interface.
chenkins Dec 14, 2024
a778e8c
Fix FlattenTreeObservation to contain all 12 features.
chenkins Dec 14, 2024
043c412
Add TreeObs TODOs.
chenkins Dec 14, 2024
ab29664
Update TODOs.
chenkins Dec 15, 2024
9acf958
Add regression test for tree obs.
chenkins Dec 20, 2024
ea222fa
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 20, 2024
e07eccc
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 21, 2024
2af24f0
Update TODOs.
chenkins Dec 22, 2024
f3547b7
Add documentation on feature groups in tree obs flattening.
chenkins Dec 22, 2024
058d349
Refactor tree obs normalization.
chenkins Dec 22, 2024
63de68c
Refactor tree obs normalization.
chenkins Dec 22, 2024
3cea4ce
Cleanup.
chenkins Dec 22, 2024
1c4f41b
Apply suggestions from code review
chenkins Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 66 additions & 68 deletions flatland/ml/observations/flatten_tree_observation_for_rail_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Adapted from https://github.com/aiAdrian/flatland_solver_policy/blob/main/observation/flatland/flatten_tree_observation_for_rail_env/flatten_tree_observation_for_rail_env_utils.py
chenkins marked this conversation as resolved.
Show resolved Hide resolved
Initially from https://github.com/instadeepai/Mava/blob/0.0.9/mava/wrappers/flatland.py
"""
from typing import Optional

import gymnasium as gym
Expand All @@ -8,72 +12,6 @@
from flatland.ml.observations.gym_observation_builder import GymObservationBuilder


# from https://github.com/aiAdrian/flatland_solver_policy/blob/main/observation/flatland/flatten_tree_observation_for_rail_env/flatten_tree_observation_for_rail_env_utils.py
# initially from https://github.com/instadeepai/Mava/blob/0.0.9/mava/wrappers/flatland.py
def max_lt(seq, val):
    """
    Return the greatest strictly positive item in seq that is < val.

    Parameters
    ----------
    seq
        Iterable of numbers to search.
    val
        Exclusive upper bound; items >= val are ignored.

    Returns
    -------
    The greatest item with 0 < item < val, or 0 if seq is empty or no item
    qualifies (None is never returned).
    """
    # Negative items and zeros can never beat the fallback of 0, so only
    # strictly positive candidates need to be considered.
    return max((item for item in seq if 0 < item < val), default=0)


def min_gt(seq, val):
    """
    Return the smallest item in seq that is >= val.

    Despite the name, the comparison is inclusive: an item equal to val
    qualifies.

    Parameters
    ----------
    seq
        Iterable of numbers to search.
    val
        Inclusive lower bound; items < val are ignored.

    Returns
    -------
    The smallest item with item >= val, or np.inf if seq is empty or no
    item qualifies (None is never returned).
    """
    return min((item for item in seq if item >= val), default=np.inf)


# TODO documentation, extract to library module or to normalized tree obs class?
def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
    """
    Scale an observation by its value range and clip the result.

    Parameters
    ----------
    obs
        Observation that should be normalized
    clip_min
        min value where observation will be clipped
    clip_max
        max value where observation will be clipped
    fixed_radius
        If > 0, used directly as the upper end of the scale. Otherwise the
        upper end is derived from obs as max(1, greatest item < 1000) + 1.
    normalize_to_range
        If True, the lower end of the scale is the smallest non-negative
        item of obs (capped at the upper end); if False it is 0.

    Returns
    -------
    normalized and clipped observation (numpy array)
    """
    if fixed_radius > 0:
        max_obs = fixed_radius
    else:
        # Items >= 1000 are ignored when deriving the scale; the result is
        # always at least 2, so the divisions below never divide by zero.
        max_obs = max(1, max_lt(obs, 1000)) + 1

    min_obs = min_gt(obs, 0) if normalize_to_range else 0
    # min_gt returns inf when obs has no non-negative item; cap the lower
    # end at the upper end so the scale stays finite.
    if min_obs > max_obs:
        min_obs = max_obs

    values = np.array(obs)
    if max_obs == min_obs:
        # Degenerate range: fall back to scaling by the upper end only.
        return np.clip(values / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    return np.clip((values - min_obs) / norm, clip_min, clip_max)


# TODO passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
class FlattenTreeObsForRailEnv(GymObservationBuilder[np.ndarray], TreeObsForRailEnv):
"""
Expand Down Expand Up @@ -210,13 +148,73 @@ def __init__(self, observation_radius: int = 2, **kwargs):
super().__init__(**kwargs)
self.observation_radius = observation_radius

def _max_lt(self, seq, val):
"""
Return greatest item in seq for which item < val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
max = 0
idx = len(seq) - 1
while idx >= 0:
if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
max = seq[idx]
idx -= 1
return max

def _min_gt(self, seq, val):
"""
Return smallest item in seq for which item > val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
min = np.inf
idx = len(seq) - 1
while idx >= 0:
if seq[idx] >= val and seq[idx] < min:
min = seq[idx]
idx -= 1
return min

def _norm_obs_clip(self, obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
"""
This function returns the difference between min and max value of an observation.

Parameters
----------
obs
Observation that should be normalized
clip_min
min value where observation will be clipped
clip_max
max value where observation will be clipped
fixed_radius
normalize_to_range

Returns
-------
normalized and clipped observation
"""
if fixed_radius > 0:
max_obs = fixed_radius
else:
max_obs = max(1, self._max_lt(obs, 1000)) + 1

min_obs = 0
if normalize_to_range:
min_obs = self._min_gt(obs, 0)
if min_obs > max_obs:
min_obs = max_obs
if max_obs == min_obs:
return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
norm = np.abs(max_obs - min_obs)
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)

def normalize_obs(self, obs):
    """
    Normalize a flattened observation, one feature group at a time.

    The observation is split into three consecutive groups using
    self._len_data, self._len_distance and self._len_agent_data:

    - data: scaled with a fixed radius of self.observation_radius,
    - distance: scaled to its own value range,
    - agent data: only clipped to [-1, 1].

    Parameters
    ----------
    obs
        Flattened observation (sliceable, e.g. a numpy array).

    Returns
    -------
    numpy array with the normalized groups concatenated in the same order
    """
    distance_end = self._len_data + self._len_distance
    data = obs[:self._len_data]
    distance = obs[self._len_data:distance_end]
    agent_data = obs[-self._len_agent_data:]

    # Each group is normalized exactly once, via the class's own helper.
    data = self._norm_obs_clip(data, fixed_radius=self.observation_radius)
    distance = self._norm_obs_clip(distance, normalize_to_range=True)
    agent_data = np.clip(agent_data, -1, 1)
    return np.concatenate((data, distance, agent_data))

Expand Down
Loading