
74 Policy evaluation and training cli (rllib) #85

Open · wants to merge 30 commits into main

Changes from 29 commits

Commits (30)
68f7394
Add flatland training cli with ray.
chenkins Nov 1, 2024
f62da96
Fix error from ray passive_env_checker.
chenkins Nov 6, 2024
6f36350
Fix flattened gym observation builders.
chenkins Nov 6, 2024
63a65f3
Unit test for (gym) observation builder returned (d)types sizes/shapes.
chenkins Nov 6, 2024
7e682de
Extract ray examples.
chenkins Nov 11, 2024
6abc591
Update TODO.
chenkins Nov 16, 2024
3f903ca
Add ray address cli param.
chenkins Nov 19, 2024
3fe01dc
Cleanup training cli example.
chenkins Dec 6, 2024
e9d389b
Create test stub for env_creator.
chenkins Dec 6, 2024
ecf6deb
Mark training test as slow.
chenkins Dec 6, 2024
5c7a17e
Gym obs builder unit tests.
chenkins Dec 6, 2024
73d6bd2
Code cleanup gym obs builders.
chenkins Dec 6, 2024
37fe60f
Code cleanup gym obs builders.
chenkins Dec 6, 2024
113e719
Add option checkpointing.
chenkins Dec 7, 2024
b5e4681
Add rllib_demo.ipynb.
chenkins Dec 7, 2024
e2a758c
Remove obsolete ray shutdown.
chenkins Dec 9, 2024
952688d
Add rllib_demo.ipynb.
chenkins Dec 9, 2024
8b1f65a
Cleanup example cli interface.
chenkins Dec 14, 2024
a778e8c
Fix FlattenTreeObservation to contain all 12 features.
chenkins Dec 14, 2024
043c412
Add TreeObs TODOs.
chenkins Dec 14, 2024
ab29664
Update TODOs.
chenkins Dec 15, 2024
9acf958
Add regression test for tree obs.
chenkins Dec 20, 2024
ea222fa
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 20, 2024
e07eccc
Split flattening and normalization in FlattenTreeObsForRailEnv.
chenkins Dec 21, 2024
2af24f0
Update TODOs.
chenkins Dec 22, 2024
f3547b7
Add documentation on feature groups in tree obs flattening.
chenkins Dec 22, 2024
058d349
Refactor tree obs normalization.
chenkins Dec 22, 2024
63de68c
Refactor tree obs normalization.
chenkins Dec 22, 2024
3cea4ce
Cleanup.
chenkins Dec 22, 2024
1c4f41b
Apply suggestions from code review
chenkins Jan 10, 2025
22 changes: 11 additions & 11 deletions flatland/core/env_observation_builder.py
@@ -8,20 +8,23 @@
multi-agent environments.

"""
from typing import Optional, List
from typing import Optional, List, Dict, Generic, TypeVar

import numpy as np

from flatland.core.env import Environment

ObservationType = TypeVar('ObservationType')
AgentHandle = int

class ObservationBuilder:

class ObservationBuilder(Generic[ObservationType]):
"""
ObservationBuilder base class.
"""

def __init__(self):
self.env = None
self.env: Optional[Environment] = None

def set_env(self, env: Environment):
self.env: Environment = env
@@ -32,7 +35,7 @@ def reset(self):
"""
raise NotImplementedError()

def get_many(self, handles: Optional[List[int]] = None):
def get_many(self, handles: Optional[List[AgentHandle]] = None) -> Dict[AgentHandle, ObservationType]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
@@ -55,7 +58,7 @@ def get_many(self, handles: Optional[List[int]] = None):
observations[h] = self.get(h)
return observations

def get(self, handle: int = 0):
def get(self, handle: AgentHandle = 0) -> ObservationType:
"""
Called whenever an observation has to be computed for the `env` environment, possibly
for each agent independently (agent id `handle`).
@@ -72,14 +75,14 @@ def get(self, handle: int = 0):
"""
raise NotImplementedError()

def _get_one_hot_for_agent_direction(self, agent):
def _get_one_hot_for_agent_direction(self, agent) -> np.ndarray:
"""Retuns the agent's direction to one-hot encoding."""
direction = np.zeros(4)
direction[agent.direction] = 1
return direction


class DummyObservationBuilder(ObservationBuilder):
class DummyObservationBuilder(ObservationBuilder[bool]):
"""
DummyObservationBuilder class which returns dummy observations
This is used in the evaluation service
@@ -91,8 +94,5 @@ def __init__(self):
def reset(self):
pass

def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True

def get(self, handle: int = 0) -> bool:
def get(self, handle: AgentHandle = 0) -> bool:
return True
Empty file.
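As a side note on the typed base class above: a minimal sketch (not part of this PR; the class name and observation layout are hypothetical) of how a downstream builder can use the new Generic[ObservationType] parameterization and the AgentHandle alias:

import numpy as np

from flatland.core.env_observation_builder import AgentHandle, ObservationBuilder


class AgentPositionObservation(ObservationBuilder[np.ndarray]):
    """Toy builder: each agent's (row, col, direction) as a float32 array."""

    def reset(self):
        # No per-episode caches to rebuild in this toy example.
        pass

    def get(self, handle: AgentHandle = 0) -> np.ndarray:
        agent = self.env.agents[handle]
        # Off-map agents have position None; fall back to their initial position.
        row, col = agent.position if agent.position is not None else agent.initial_position
        return np.array([row, col, agent.direction], dtype=np.float32)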
46 changes: 46 additions & 0 deletions flatland/env_generation/env_creator.py
@@ -0,0 +1,46 @@
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import ParamMalfunctionGen, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator


# defaults from Flatland 3 Round 2 Test_0, see https://flatland.aicrowd.com/challenges/flatland3/envconfig.html
def env_creator(n_agents=7,
Review comment (Contributor): Maybe use env_generator to match the naming of the rail and line generators?

x_dim=30,
y_dim=30,
n_cities=2,
max_rail_pairs_in_city=4,
grid_mode=False,
max_rails_between_cities=2,
malfunction_duration_min=20,
malfunction_duration_max=50,
malfunction_interval=540,
speed_ratios=None,
seed=42,
obs_builder_object=None) -> RailEnv:
if speed_ratios is None:
speed_ratios = {1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))

env = RailEnv(
width=x_dim,
height=y_dim,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=grid_mode,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_pairs_in_city
),
malfunction_generator=ParamMalfunctionGen(MalfunctionParameters(
min_duration=malfunction_duration_min, max_duration=malfunction_duration_max, malfunction_rate=1.0 / malfunction_interval)),
line_generator=sparse_line_generator(speed_ratio_map=speed_ratios, seed=seed),
number_of_agents=n_agents,
obs_builder_object=obs_builder_object,
record_steps=True
)
env.reset(random_seed=seed)
return env
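A short usage sketch of the new helper (the import path follows the file header above; the second, smaller instance is illustrative, not a default from the PR):

from flatland.env_generation.env_creator import env_creator
from flatland.envs.observations import GlobalObsForRailEnv

# Flatland 3 Round 2 Test_0 defaults: 7 agents on a 30x30 grid with tree observations.
env = env_creator()
assert env.get_num_agents() == 7

# A smaller instance with a global observation builder instead of the tree builder.
small_env = env_creator(n_agents=2, obs_builder_object=GlobalObsForRailEnv(), seed=7)
obs, info = small_env.reset(random_seed=7)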
27 changes: 6 additions & 21 deletions flatland/envs/observations.py
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_observation_builder import ObservationBuilder, AgentHandle
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
@@ -31,7 +31,7 @@
'childs')


class TreeObsForRailEnv(ObservationBuilder):
class TreeObsForRailEnv(ObservationBuilder[Node]):
"""
TreeObsForRailEnv object.

@@ -56,7 +56,7 @@ def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
def reset(self):
self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}

def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
def get_many(self, handles: Optional[List[AgentHandle]] = None) -> Dict[AgentHandle, Node]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
@@ -113,7 +113,7 @@ def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:

return observations

def get(self, handle: int = 0) -> Node:
def get(self, handle: AgentHandle = 0) -> Node:
"""
Computes the current observation for agent `handle` in env

@@ -211,7 +211,6 @@ def get(self, handle: int = 0) -> Node:
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()

# was referring to TreeObsForRailEnv.Node
root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
@@ -440,7 +439,6 @@ def _explore_branch(self, handle, position, direction, tot_dist, depth):
dist_to_next_branch = tot_dist
dist_min_to_target = distance_map_handle[position[0], position[1], direction]

# TreeObsForRailEnv.Node
node = Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
@@ -532,7 +530,7 @@ def _reverse_dir(self, direction):
return int((direction + 2) % 4)


class GlobalObsForRailEnv(ObservationBuilder):
class GlobalObsForRailEnv(ObservationBuilder[Tuple[np.ndarray, np.ndarray, np.ndarray]]):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
@@ -565,7 +563,7 @@ def reset(self):
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)

def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
def get(self, handle: AgentHandle = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

agent = self.env.agents[handle]
if agent.state.is_off_map_state():
@@ -580,10 +578,6 @@ def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

# TODO can we do this more elegantly?
# for r in range(self.env.height):
# for c in range(self.env.width):
# obs_agents_state[(r, c)][4] = 0
obs_agents_state[:, :, 4] = 0

obs_agents_state[agent_virtual_position][0] = agent.direction
@@ -696,15 +690,6 @@ def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarra
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction

def get_many(self, handles: Optional[List[int]] = None) -> Dict[
int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""

return super().get_many(handles)

def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
data_collection = False
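For context on the Dict[AgentHandle, Node] annotation above, a small sketch of consuming tree observations (built on the env_creator helper from this PR; field names follow the Node definition in this file):

from flatland.env_generation.env_creator import env_creator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

obs_builder = TreeObsForRailEnv(max_depth=2,
                                predictor=ShortestPathPredictorForRailEnv(max_depth=30))
env = env_creator(n_agents=3, obs_builder_object=obs_builder)

# reset() returns the Dict[AgentHandle, Node] produced by get_many().
obs_dict, info = env.reset(random_seed=3)
for handle in env.get_agent_handles():
    node = obs_dict[handle]
    print(handle, node.dist_min_to_target, sorted(node.childs.keys()))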
15 changes: 6 additions & 9 deletions flatland/envs/rail_env.py
@@ -5,7 +5,6 @@
from typing import List, Optional, Dict, Tuple

import numpy as np
from flatland.utils import seeding

# from flatland.envs.timetable_generators import timetable_generator
import flatland.envs.timetable_generators as ttg
@@ -27,6 +26,7 @@
from flatland.envs.step_utils import env_utils
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.utils import seeding
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \
enable_infrastructure_lru_cache
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
@@ -192,8 +192,6 @@ def __init__(self,
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)

self.action_space = [5]
Review comment (Contributor Author): Is this removal safe? Should it also be removed from flatland.core.Environment.env, or redefined there?

Review comment (Contributor): What does "removal safe - remove from ..." mean? Do you want to delete the file or just action_space?

self._seed(seed=random_seed)

self.agent_positions = None
@@ -220,8 +218,8 @@ def _seed(self, seed):
return [seed]

# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
def get_agent_handles(self) -> List[int]:
return list(range(self.get_num_agents()))

def get_num_agents(self) -> int:
return len(self.agents)
@@ -500,7 +498,7 @@ def handle_done_state(self, agent):
if self.remove_agents_at_target:
agent.position = None

def step(self, action_dict_: Dict[int, RailEnvActions]):
def step(self, action_dict: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
@@ -526,7 +524,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]):
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)

# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)

preprocessed_action = self.preprocess_action(action, agent)

@@ -629,7 +627,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]):

self._update_agent_positions_map()
if self.record_steps:
self.record_timestep(action_dict_)
self.record_timestep(action_dict)

return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

@@ -727,7 +725,6 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)

def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
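To illustrate the renamed action_dict parameter and the concrete List[int] now returned by get_agent_handles(), a minimal random-action rollout (a sketch, not code from this PR):

import numpy as np

from flatland.env_generation.env_creator import env_creator
from flatland.envs.rail_env import RailEnvActions

env = env_creator(n_agents=3, seed=1)
obs, info = env.reset(random_seed=1)
rng = np.random.default_rng(1)

for _ in range(50):
    # get_agent_handles() returns a plain list; step() now takes `action_dict`.
    action_dict = {h: RailEnvActions(int(rng.integers(0, 5)))
                   for h in env.get_agent_handles()}
    obs, rewards, dones, info = env.step(action_dict)
    if dones["__all__"]:
        break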
Empty file added flatland/ml/__init__.py
Empty file.
Empty file.