Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

74 Policy evaluation and training cli (rllib) #85

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
68f7394
Add flatland training cli with ray.
chenkins Nov 1, 2024
f62da96
Fix error from ray passive_env_checker.
chenkins Nov 6, 2024
6f36350
Fix flattened gym observation builders.
chenkins Nov 6, 2024
63a65f3
Unit test for (gym) observation builder returned (d)types sizes/shapes.
chenkins Nov 6, 2024
7e682de
Extract ray examples.
chenkins Nov 11, 2024
6abc591
Update TODO.
chenkins Nov 16, 2024
3f903ca
Add ray address cli param.
chenkins Nov 19, 2024
3fe01dc
Cleanup training cli example.
chenkins Dec 6, 2024
e9d389b
Create test stub for env_creator.
chenkins Dec 6, 2024
ecf6deb
Mark training test as slow.
chenkins Dec 6, 2024
5c7a17e
Gym obs builder unit tests.
chenkins Dec 6, 2024
73d6bd2
Code cleanup gym obs builders.
chenkins Dec 6, 2024
37fe60f
Code cleanup gym obs builders.
chenkins Dec 6, 2024
113e719
Add option checkpointing.
chenkins Dec 7, 2024
b5e4681
Add rllib_demo.ipynb.
chenkins Dec 7, 2024
e2a758c
Remove obsolete ray shutdown.
chenkins Dec 9, 2024
952688d
Add rllib_demo.ipynb.
chenkins Dec 9, 2024
8b1f65a
Cleanup example cli interface.
chenkins Dec 14, 2024
a778e8c
Fix FlattenTreeObservation to contain all 12 features.
chenkins Dec 14, 2024
043c412
Add TreeObs TODOs.
chenkins Dec 14, 2024
ab29664
Update TODOs.
chenkins Dec 15, 2024
9acf958
Add regression test for tree obs.
chenkins Dec 20, 2024
ea222fa
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 20, 2024
e07eccc
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 21, 2024
2af24f0
Update TODOs.
chenkins Dec 22, 2024
f3547b7
Add documentation on feature groups in tree obs flattening.
chenkins Dec 22, 2024
058d349
Refactor tree obs normalization.
chenkins Dec 22, 2024
63de68c
Refactor tree obs normalization.
chenkins Dec 22, 2024
3cea4ce
Cleanup.
chenkins Dec 22, 2024
1c4f41b
Apply suggestions from code review
chenkins Jan 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions flatland/core/env_observation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ObservationBuilder:
"""

def __init__(self):
self.env = None
self.env: Environment = None

def set_env(self, env: Environment):
self.env: Environment = env
Expand Down Expand Up @@ -91,8 +91,5 @@ def __init__(self):
def reset(self):
pass

def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True

def get(self, handle: int = 0) -> bool:
return True
Empty file.
47 changes: 47 additions & 0 deletions flatland/env_generation/env_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import ParamMalfunctionGen, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator


# TODO test this is determinist - is seeding done correctly?
# defaults from Flatland 3 Round 2 Test_0, see https://flatland.aicrowd.com/challenges/flatland3/envconfig.html
def env_creator(n_agents=7,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use env_generator to use same naming as for rail and line generators?

x_dim=30,
y_dim=30,
n_cities=2,
max_rail_pairs_in_city=4,
grid_mode=False,
max_rails_between_cities=2,
malfunction_duration_min=20,
malfunction_duration_max=50,
malfunction_interval=540,
speed_ratios=None,
seed=42,
obs_builder_object=None) -> RailEnv:
if speed_ratios is None:
speed_ratios = {1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))

env = RailEnv(
width=x_dim,
height=y_dim,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=grid_mode,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_pairs_in_city
),
malfunction_generator=ParamMalfunctionGen(MalfunctionParameters(
min_duration=malfunction_duration_min, max_duration=malfunction_duration_max, malfunction_rate=1.0 / malfunction_interval)),
line_generator=sparse_line_generator(speed_ratio_map=speed_ratios, seed=seed),
number_of_agents=n_agents,
obs_builder_object=obs_builder_object,
record_steps=True
)
env.reset(random_seed=seed)
return env
17 changes: 1 addition & 16 deletions flatland/envs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Collection of environment-specific ObservationBuilder.
"""
import collections
from typing import Optional, List, Dict, Tuple
from typing import Optional, List, Dict

import numpy as np

Expand Down Expand Up @@ -211,7 +211,6 @@ def get(self, handle: int = 0) -> Node:
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()

# was referring to TreeObsForRailEnv.Node
root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
Expand Down Expand Up @@ -440,7 +439,6 @@ def _explore_branch(self, handle, position, direction, tot_dist, depth):
dist_to_next_branch = tot_dist
dist_min_to_target = distance_map_handle[position[0], position[1], direction]

# TreeObsForRailEnv.Node
node = Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
Expand Down Expand Up @@ -580,10 +578,6 @@ def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

# TODO can we do this more elegantly?
# for r in range(self.env.height):
# for c in range(self.env.width):
# obs_agents_state[(r, c)][4] = 0
obs_agents_state[:, :, 4] = 0

obs_agents_state[agent_virtual_position][0] = agent.direction
Expand Down Expand Up @@ -696,15 +690,6 @@ def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarra
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction

def get_many(self, handles: Optional[List[int]] = None) -> Dict[
int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""

return super().get_many(handles)

def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
data_collection = False
Expand Down
15 changes: 6 additions & 9 deletions flatland/envs/rail_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List, Optional, Dict, Tuple

import numpy as np
from flatland.utils import seeding

# from flatland.envs.timetable_generators import timetable_generator
import flatland.envs.timetable_generators as ttg
Expand All @@ -27,6 +26,7 @@
from flatland.envs.step_utils import env_utils
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.utils import seeding
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \
enable_infrastructure_lru_cache
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
Expand Down Expand Up @@ -192,8 +192,6 @@ def __init__(self,
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)

self.action_space = [5]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this removal safe? Remove from flatland.core.Environment.env as well or redefine there?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does mean removal safe - remove from ... ? You like to delete the file or just action_space ?


self._seed(seed=random_seed)

self.agent_positions = None
Expand All @@ -220,8 +218,8 @@ def _seed(self, seed):
return [seed]

# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
def get_agent_handles(self) -> List[int]:
return list(range(self.get_num_agents()))

def get_num_agents(self) -> int:
return len(self.agents)
Expand Down Expand Up @@ -500,7 +498,7 @@ def handle_done_state(self, agent):
if self.remove_agents_at_target:
agent.position = None

def step(self, action_dict_: Dict[int, RailEnvActions]):
def step(self, action_dict: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
Expand All @@ -526,7 +524,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]):
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)

# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)

preprocessed_action = self.preprocess_action(action, agent)

Expand Down Expand Up @@ -629,7 +627,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]):

self._update_agent_positions_map()
if self.record_steps:
self.record_timestep(action_dict_)
self.record_timestep(action_dict)

return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

Expand Down Expand Up @@ -727,7 +725,6 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)

def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
Expand Down
Empty file added flatland/ml/__init__.py
Empty file.
Empty file.
156 changes: 156 additions & 0 deletions flatland/ml/observations/flatten_tree_observation_for_rail_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Optional, List

import gymnasium as gym
import numpy as np

from flatland.envs.observations import TreeObsForRailEnv
from flatland.ml.observations.gym_observation_builder import GymObservationBuilder


# from https://github.com/aiAdrian/flatland_solver_policy/blob/main/observation/flatland/flatten_tree_observation_for_rail_env/flatten_tree_observation_for_rail_env_utils.py
chenkins marked this conversation as resolved.
Show resolved Hide resolved
# initially from https://github.com/instadeepai/Mava/blob/0.0.9/mava/wrappers/flatland.py
def max_lt(seq, val):
"""
Return greatest item in seq for which item < val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
max = 0
idx = len(seq) - 1
while idx >= 0:
if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
max = seq[idx]
idx -= 1
return max


def min_gt(seq, val):
"""
Return smallest item in seq for which item > val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
min = np.inf
idx = len(seq) - 1
while idx >= 0:
if seq[idx] >= val and seq[idx] < min:
min = seq[idx]
idx -= 1
return min


def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
"""
This function returns the difference between min and max value of an observation
:param obs: Observation that should be normalized
:param clip_min: min value where observation will be clipped
:param clip_max: max value where observation will be clipped
:return: returnes normalized and clipped observatoin
"""
if fixed_radius > 0:
max_obs = fixed_radius
else:
max_obs = max(1, max_lt(obs, 1000)) + 1

min_obs = 0 # min(max_obs, min_gt(obs, 0))
if normalize_to_range:
min_obs = min_gt(obs, 0)
if min_obs > max_obs:
min_obs = max_obs
if max_obs == min_obs:
return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
norm = np.abs(max_obs - min_obs)
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)


def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
data = np.zeros(6)
distance = np.zeros(1)
agent_data = np.zeros(4)

data[0] = node.dist_own_target_encountered
data[1] = node.dist_other_target_encountered
data[2] = node.dist_other_agent_encountered
data[3] = node.dist_potential_conflict
data[4] = node.dist_unusable_switch
data[5] = node.dist_to_next_branch

distance[0] = node.dist_min_to_target

agent_data[0] = node.num_agents_same_direction
agent_data[1] = node.num_agents_opposite_direction
agent_data[2] = node.num_agents_malfunctioning
agent_data[3] = node.speed_min_fractional

return data, distance, agent_data


def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (
np.ndarray, np.ndarray, np.ndarray):
if node == -np.inf:
remaining_depth = max_tree_depth - current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4

data, distance, agent_data = _split_node_into_feature_groups(node)

if not node.childs:
return data, distance, agent_data

for direction in TreeObsForRailEnv.tree_explored_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction],
current_tree_depth + 1,
max_tree_depth)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))

return data, distance, agent_data


def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
"""
This function splits the tree into three difference arrays of values
"""
data, distance, agent_data = _split_node_into_feature_groups(tree)

for direction in TreeObsForRailEnv.tree_explored_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1,
max_tree_depth)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))

return data, distance, agent_data


def normalize_observation(observation, tree_depth: int, observation_radius=0):
"""
This function normalizes the observation used by the RL algorithm
"""
data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)

data = norm_obs_clip(data, fixed_radius=observation_radius)
distance = norm_obs_clip(distance, normalize_to_range=True)
agent_data = np.clip(agent_data, -1, 1)
normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
return normalized_obs

# TODO passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
# TODO can we not use gym flatteners instead?
# TODO call it ...Gym as well?
class FlattenTreeObsForRailEnv(TreeObsForRailEnv, GymObservationBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.observation_radius = 2

def get_many(self, handles: Optional[List[int]] = None):
obs = super(FlattenTreeObsForRailEnv, self).get_many(handles)
obs = {i: normalize_observation(obs[i], tree_depth=self.max_depth, observation_radius=self.observation_radius) for i in range(len(handles))}
return obs

def get_observation_space(self, handle: int = 0):
# max_depth=1 -> 55, max_depth=2 -> 231, max_depth=3 -> 935, ...
k = 11
for _ in range(self.max_depth):
k = k * 4 + 11
return gym.spaces.Box(-1, 2, (k,), dtype=np.float64)
Loading