-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
74 Policy evaluation and training cli (rllib) #85
base: main
Are you sure you want to change the base?
Changes from 1 commit
68f7394
f62da96
6f36350
63a65f3
7e682de
6abc591
3f903ca
3fe01dc
e9d389b
ecf6deb
5c7a17e
73d6bd2
37fe60f
113e719
b5e4681
e2a758c
952688d
8b1f65a
a778e8c
043c412
ab29664
9acf958
ea222fa
e07eccc
2af24f0
f3547b7
058d349
63de68c
3cea4ce
1c4f41b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from flatland.envs.line_generators import sparse_line_generator | ||
from flatland.envs.malfunction_generators import ParamMalfunctionGen, MalfunctionParameters | ||
from flatland.envs.observations import TreeObsForRailEnv | ||
from flatland.envs.predictions import ShortestPathPredictorForRailEnv | ||
from flatland.envs.rail_env import RailEnv | ||
from flatland.envs.rail_generators import sparse_rail_generator | ||
|
||
|
||
# defaults from Flatland 3 Round 2 Test_0, see https://flatland.aicrowd.com/challenges/flatland3/envconfig.html
def env_creator(n_agents=7,
                x_dim=30,
                y_dim=30,
                n_cities=2,
                max_rail_pairs_in_city=4,
                grid_mode=False,
                max_rails_between_cities=2,
                malfunction_duration_min=20,
                malfunction_duration_max=50,
                malfunction_interval=540,
                speed_ratios=None,
                seed=42,
                obs_builder_object=None,
                record_steps=True) -> RailEnv:
    """
    Build a ``RailEnv`` with sparse rail/line generators and parameterized malfunctions.

    :param n_agents: number of trains in the environment
    :param x_dim: grid width
    :param y_dim: grid height
    :param n_cities: maximum number of cities for the sparse rail generator
    :param max_rail_pairs_in_city: maximum rail pairs inside a city
    :param grid_mode: whether cities are placed on a regular grid
    :param max_rails_between_cities: maximum parallel rails between two cities
    :param malfunction_duration_min: minimum malfunction duration (steps)
    :param malfunction_duration_max: maximum malfunction duration (steps)
    :param malfunction_interval: expected steps between malfunctions; the
        generator receives the rate ``1 / malfunction_interval``
    :param speed_ratios: mapping speed -> fraction of agents with that speed;
        defaults to the Flatland 3 Round 2 distribution when ``None``
    :param seed: seed forwarded to both rail and line generators
    :param obs_builder_object: observation builder; defaults to a depth-3 tree
        observation with a shortest-path predictor when ``None``
    :param record_steps: whether the env records every timestep (e.g. for
        episode replay); previously hard-coded to ``True``, default unchanged
    :return: a freshly constructed ``RailEnv``
    """
    # Mutable defaults are created per call so envs never share state.
    if speed_ratios is None:
        speed_ratios = {1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}
    if obs_builder_object is None:
        obs_builder_object = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))

    return RailEnv(
        width=x_dim,
        height=y_dim,
        rail_generator=sparse_rail_generator(
            max_num_cities=n_cities,
            seed=seed,
            grid_mode=grid_mode,
            max_rails_between_cities=max_rails_between_cities,
            max_rail_pairs_in_city=max_rail_pairs_in_city
        ),
        # malfunction_rate is expected malfunctions per step, i.e. 1 / interval
        malfunction_generator=ParamMalfunctionGen(MalfunctionParameters(
            min_duration=malfunction_duration_min,
            max_duration=malfunction_duration_max,
            malfunction_rate=1.0 / malfunction_interval)),
        line_generator=sparse_line_generator(speed_ratio_map=speed_ratios, seed=seed),
        number_of_agents=n_agents,
        obs_builder_object=obs_builder_object,
        record_steps=record_steps
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
from typing import List, Optional, Dict, Tuple | ||
|
||
import numpy as np | ||
from flatland.utils import seeding | ||
|
||
# from flatland.envs.timetable_generators import timetable_generator | ||
import flatland.envs.timetable_generators as ttg | ||
|
@@ -27,6 +26,7 @@ | |
from flatland.envs.step_utils import env_utils | ||
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals | ||
from flatland.envs.step_utils.transition_utils import check_valid_action | ||
from flatland.utils import seeding | ||
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \ | ||
enable_infrastructure_lru_cache | ||
from flatland.utils.rendertools import RenderTool, AgentRenderVariant | ||
|
@@ -192,8 +192,6 @@ def __init__(self, | |
self.num_resets = 0 | ||
self.distance_map = DistanceMap(self.agents, self.height, self.width) | ||
|
||
self.action_space = [5] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this removal safe? Remove from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does mean removal safe - remove from ... ? You like to delete the file or just action_space ? |
||
|
||
self._seed(seed=random_seed) | ||
|
||
self.agent_positions = None | ||
|
@@ -220,8 +218,8 @@ def _seed(self, seed): | |
return [seed] | ||
|
||
# no more agent_handles | ||
def get_agent_handles(self): | ||
return range(self.get_num_agents()) | ||
def get_agent_handles(self) -> List[int]: | ||
return list(range(self.get_num_agents())) | ||
|
||
def get_num_agents(self) -> int: | ||
return len(self.agents) | ||
|
@@ -500,7 +498,7 @@ def handle_done_state(self, agent): | |
if self.remove_agents_at_target: | ||
agent.position = None | ||
|
||
def step(self, action_dict_: Dict[int, RailEnvActions]): | ||
def step(self, action_dict: Dict[int, RailEnvActions]): | ||
""" | ||
Updates rewards for the agents at a step. | ||
""" | ||
|
@@ -526,7 +524,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]): | |
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random) | ||
|
||
# Get action for the agent | ||
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING) | ||
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING) | ||
|
||
preprocessed_action = self.preprocess_action(action, agent) | ||
|
||
|
@@ -629,7 +627,7 @@ def step(self, action_dict_: Dict[int, RailEnvActions]): | |
|
||
self._update_agent_positions_map() | ||
if self.record_steps: | ||
self.record_timestep(action_dict_) | ||
self.record_timestep(action_dict) | ||
|
||
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() | ||
|
||
|
@@ -727,7 +725,6 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar | |
return self.update_renderer(mode=mode, show=show, show_observations=show_observations, | ||
show_predictions=show_predictions, | ||
show_rowcols=show_rowcols, return_image=return_image) | ||
|
||
def initialize_renderer(self, mode, gl, | ||
agent_render_variant, | ||
show_debug, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
from typing import List, Optional, Tuple
|
||
import gymnasium as gym | ||
import numpy as np | ||
|
||
from flatland.envs.observations import TreeObsForRailEnv | ||
from flatland.ml.observations.gym_observation_builder import GymObservationBuilder | ||
|
||
|
||
# from https://github.com/aiAdrian/flatland_solver_policy/blob/main/observation/flatland/flatten_tree_observation_for_rail_env/flatten_tree_observation_for_rail_env_utils.py | ||
chenkins marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# initially from https://github.com/instadeepai/Mava/blob/0.0.9/mava/wrappers/flatland.py | ||
def max_lt(seq, val):
    """
    Return the greatest non-negative item in ``seq`` that is strictly smaller
    than ``val``.

    Returns ``0`` (not ``None``, despite what earlier copies of this docstring
    claimed) if ``seq`` is empty or no item satisfies ``0 <= item < val``.
    """
    # max() with a default replaces the manual reverse index walk and avoids
    # shadowing the builtin `max` with a local variable.
    return max((item for item in seq if 0 <= item < val), default=0)
|
||
|
||
def min_gt(seq, val):
    """
    Return the smallest item in ``seq`` for which ``item >= val`` holds.

    Returns ``np.inf`` (not ``None``, despite what earlier copies of this
    docstring claimed) if ``seq`` is empty or all items are below ``val``.
    """
    # min() with a default replaces the manual reverse index walk and avoids
    # shadowing the builtin `min` with a local variable.
    return min((item for item in seq if item >= val), default=np.inf)
|
||
|
||
def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
    """
    Scale an observation vector into roughly [0, 1] and clip it.

    :param obs: observation sequence to normalize
    :param clip_min: lower clipping bound applied after scaling
    :param clip_max: upper clipping bound applied after scaling
    :param fixed_radius: if > 0, divide by this constant instead of the
        data-derived maximum
    :param normalize_to_range: if True, shift by the smallest item >= 0
        before scaling so the minimum maps to 0
    :return: normalized and clipped observation as a numpy array
    """
    if fixed_radius > 0:
        max_obs = fixed_radius
    else:
        # +1 keeps the divisor strictly above the largest entry < 1000
        # (entries >= 1000 are treated as "infinite" sentinels by max_lt).
        max_obs = max(1, max_lt(obs, 1000)) + 1

    min_obs = 0  # min(max_obs, min_gt(obs, 0))
    if normalize_to_range:
        min_obs = min_gt(obs, 0)
    if min_obs > max_obs:
        min_obs = max_obs
    if max_obs == min_obs:
        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
|
||
|
||
def _split_node_into_feature_groups(node) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Split a single tree-observation node into its three feature groups.

    :param node: a tree-observation node (expects the ``dist_*``, ``num_agents_*``
        and ``speed_min_fractional`` attributes of Flatland's tree obs node)
    :return: tuple of (distance features[6], min-distance-to-target[1],
        agent features[4]) as float arrays
    """
    # Building the arrays directly replaces the zeros-then-assign pattern;
    # dtype=float matches the np.zeros default (float64).
    data = np.array([
        node.dist_own_target_encountered,
        node.dist_other_target_encountered,
        node.dist_other_agent_encountered,
        node.dist_potential_conflict,
        node.dist_unusable_switch,
        node.dist_to_next_branch,
    ], dtype=float)

    distance = np.array([node.dist_min_to_target], dtype=float)

    agent_data = np.array([
        node.num_agents_same_direction,
        node.num_agents_opposite_direction,
        node.num_agents_malfunctioning,
        node.speed_min_fractional,
    ], dtype=float)

    return data, distance, agent_data
|
||
|
||
def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray]:
    """
    Recursively flatten ``node`` and its children into the three feature groups.

    :param node: a tree-observation node, or ``-np.inf`` for an unexplored branch
    :param current_tree_depth: depth of ``node`` in the observation tree
    :param max_tree_depth: maximum depth of the observation tree
    :return: tuple of (distance features, min-distance features, agent features)
        covering this node and its entire (padded) subtree
    """
    if node == -np.inf:
        # Unexplored branch: pad with -inf placeholders for the whole missing
        # subtree so every observation has a fixed length.
        remaining_depth = max_tree_depth - current_tree_depth
        # Node count of a full 4-ary tree of the remaining depth; reference:
        # https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
        return ([-np.inf] * num_remaining_nodes * 6,
                [-np.inf] * num_remaining_nodes,
                [-np.inf] * num_remaining_nodes * 4)

    data, distance, agent_data = _split_node_into_feature_groups(node)

    if not node.childs:
        return data, distance, agent_data

    for direction in TreeObsForRailEnv.tree_explored_actions_char:
        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(
            node.childs[direction], current_tree_depth + 1, max_tree_depth)
        data = np.concatenate((data, sub_data))
        distance = np.concatenate((distance, sub_distance))
        agent_data = np.concatenate((agent_data, sub_agent_data))

    return data, distance, agent_data
|
||
|
||
def split_tree_into_feature_groups(tree, max_tree_depth: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Split a whole observation tree into three feature arrays.

    :param tree: root node of the tree observation
    :param max_tree_depth: maximum depth of the observation tree (used to pad
        unexplored branches to a fixed length)
    :return: tuple of (distance features, min-distance features, agent features)
    """
    data, distance, agent_data = _split_node_into_feature_groups(tree)

    for direction in TreeObsForRailEnv.tree_explored_actions_char:
        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(
            tree.childs[direction], 1, max_tree_depth)
        data = np.concatenate((data, sub_data))
        distance = np.concatenate((distance, sub_distance))
        agent_data = np.concatenate((agent_data, sub_agent_data))

    return data, distance, agent_data
|
||
|
||
def normalize_observation(observation, tree_depth: int, observation_radius=0):
    """
    Flatten a tree observation and normalize it for consumption by an RL policy.

    :param observation: root node of the tree observation
    :param tree_depth: maximum depth of the observation tree
    :param observation_radius: fixed normalization radius for the distance
        features (0 means normalize by the observed maximum)
    :return: a single flat, normalized numpy array
    """
    tree_data, dist_data, agent_data = split_tree_into_feature_groups(observation, tree_depth)

    tree_data = norm_obs_clip(tree_data, fixed_radius=observation_radius)
    dist_data = norm_obs_clip(dist_data, normalize_to_range=True)
    agent_data = np.clip(agent_data, -1, 1)
    # A single three-way concatenate is equivalent to the nested pairwise form.
    return np.concatenate((tree_data, dist_data, agent_data))
|
||
|
||
class FlattenTreeObsForRailEnv(TreeObsForRailEnv, GymObservationBuilder):
    """Tree observation flattened into a fixed-size normalized vector for gym-style RL."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Radius used as the fixed normalization divisor for distance features.
        self.observation_radius = 2

    def get_many(self, handles: Optional[List[int]] = None):
        """
        Compute flattened, normalized observations for the given agent handles.

        :param handles: agent handles to observe; ``None`` defers to the base
            class's default (all agents)
        :return: dict keyed by agent handle with the normalized flat observation
        """
        obs = super(FlattenTreeObsForRailEnv, self).get_many(handles)
        # Iterate the returned dict instead of range(len(handles)): the original
        # crashed when handles was None (the declared default) and mis-keyed the
        # result whenever handles were not exactly 0..n-1.
        return {
            handle: normalize_observation(node, tree_depth=self.max_depth,
                                          observation_radius=self.observation_radius)
            for handle, node in obs.items()
        }

    def get_observation_space(self, handle: int = 0):
        """Return the flat Box space of the normalized tree observation."""
        # 11 features per node, 4 children per node:
        # max_depth=1 -> 55, max_depth=2 -> 231, max_depth=3 -> 935, ...
        k = 11
        for _ in range(self.max_depth):
            k = k * 4 + 11
        return gym.spaces.Box(-1, 2, (k,), dtype=np.float64)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe use
env_generator
to use same naming as for rail and line generators?