Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reward engineering #13

Merged
merged 10 commits into from
May 3, 2022
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ train_args.json
*-ignition-points.json

stable-baselines3-contrib
*.gif
slurm*.out

34 changes: 27 additions & 7 deletions cell2fire/evaluate_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import json
import os
import time
from typing import Optional

from generate_ignition_points import load_ignition_points
Expand All @@ -10,13 +11,15 @@

from cell2fire.gym_env import FireEnv
from firehose.baselines import (
HumanExpertAlgorithm,
HumanInputAlgorithm,
NaiveAlgorithm,
NoAlgorithm,
RandomAlgorithm,
)
from firehose.helpers import IgnitionPoint, IgnitionPoints
from firehose.results import FirehoseResults
from firehose.rewards import REWARD_FUNCTIONS
from firehose.video_recorder import FirehoseVideoRecorder
import numpy as np

Expand Down Expand Up @@ -51,6 +54,7 @@
"random": RandomAlgorithm,
"naive": NaiveAlgorithm,
"human": HumanInputAlgorithm,
"expert": HumanExpertAlgorithm,
"none": NoAlgorithm,
}

Expand Down Expand Up @@ -122,7 +126,7 @@ def main(args):
# Get the model for the algorithm and setup video recorder
model = _get_model(algo=args.algo, model_path=args.model_path, env=env)
video_recorder = FirehoseVideoRecorder(
env, algo=args.algo, disable_video=args.disable_video
env, args=args, disable_video=args.disable_video
)

# Override observation type if required - this is for maskable PPO mostly
Expand Down Expand Up @@ -174,6 +178,8 @@ def get_action():
if not args.disable_render:
env.render()
video_recorder.capture_frame()
if args.delay:
time.sleep(args.delay)

if reward is None:
raise RuntimeError("Reward is None. This should not happen")
Expand Down Expand Up @@ -246,21 +252,35 @@ def get_action():
parser.add_argument(
"-acd", "--action_diameter", default=1, type=int, help="Action diameter"
)
parser.add_argument(
"-i",
"--ignition_type",
default="random",
help="Specifies whether to use a random or fixed fire ignition point."
"Choices: fixed, random, or specify path to a ignition point JSON file",
)
parser.add_argument(
"-r",
"--reward",
default="FireSizeReward",
help="Specifies the reward function to use",
choices=set(REWARD_FUNCTIONS.keys()),
)
parser.add_argument(
"--disable-video", action="store_true", help="Disable video recording"
)
parser.add_argument(
"-d", "--disable-render", action="store_true", help="Disable cv2 rendering"
)
parser.add_argument(
"-pr", "--parallel-record", action="store_true", help="Disable cv2 rendering"
"--delay",
default=0.0,
type=float,
help="Delay between steps in simulation. For visualization purposes "
"- note: it doesn't get reflected in the video",
)
parser.add_argument(
"-i",
"--ignition_type",
default="random",
help="Specifies whether to use a random or fixed fire ignition point."
"Choices: fixed, random, or specify path to a ignition point JSON file",
"-pr", "--parallel-record", action="store_true", help="Disable cv2 rendering"
)
parser.add_argument(
"-o",
Expand Down
23 changes: 19 additions & 4 deletions cell2fire/firehose/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def predict(self, obs, **kwargs) -> Tuple[Any, Any]:
class NaiveAlgorithm(FlatActionSpaceAlgorithm):
"""
The Naive algorithm selects the cell that is on fire that is closest to the
ignition point in terms of Euclidean distance.
ignition point in terms of Euclidean distance if use_min is True.

If use_min is False, then we select the point furthest from the ignition.
This is our Frontier baseline essentially.

If cells have already been put out, then it will not consider them.
If there are no cells on fire, then it will return -1 (no-op).
Expand All @@ -77,9 +80,10 @@ def __init__(self, env: FireEnv):
self.prev_actions: Set[int] = {-1}
self.ignition_point = self.env.ignition_points.points[0]
self.ignition_point_yx = self.env.flatten_idx_to_yx[self.ignition_point.idx - 1]
self.use_min = True

def _update_ignition_point(self):
""" Update ignition point if it has changed, indicating a reset in the environment """
"""Update ignition point if it has changed, indicating a reset in the environment"""
current_ignition_point = self.env.ignition_points.points[0]

if current_ignition_point != self.ignition_point:
Expand Down Expand Up @@ -113,8 +117,13 @@ def predict(self, obs, **kwargs) -> Tuple[Any, Any]:
# No cells on fire so no-op
return -1, None

# Choose closest cell on fire
closest_idx = np.argmin(dist)
if self.use_min:
# Choose closest cell on fire
closest_idx = np.argmin(dist)
else:
# Choose furthest cell on fire
closest_idx = np.argmax(dist)

chosen_fire_yx = fire_yx[closest_idx]
chosen_fire_idx = self.env.yx_to_flatten_idx[chosen_fire_yx]

Expand All @@ -134,3 +143,9 @@ def predict(self, obs, **kwargs) -> Tuple[Any, Any]:

self.prev_actions.add(chosen_fire_idx)
return chosen_fire_idx, None


class HumanExpertAlgorithm(NaiveAlgorithm):
def __init__(self, env: FireEnv):
super().__init__(env)
self.use_min = False
10 changes: 8 additions & 2 deletions cell2fire/firehose/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forest_data(self) -> np.ndarray:

@cached_property
def reward_data(self) -> np.ndarray:
if(exists(self.reward_datafile)):
if exists(self.reward_datafile):
return np.loadtxt(self.reward_datafile, skiprows=6)
else:
return None
Expand Down Expand Up @@ -167,7 +167,7 @@ def manipulate_input_data_folder(
print(f"Copied modified input data folder to {tmp_dir}")

def overwrite_ignition_points(self, ignition_points: IgnitionPoints):
""" Overwrite ignition points CSV file """
"""Overwrite ignition points CSV file"""
tmp_dir = self.tmp_input_folder
ignition_points_csv = os.path.join(tmp_dir, IgnitionPoints.CSV_NAME)
# Only remove ignitions if it already exists
Expand Down Expand Up @@ -221,3 +221,9 @@ def generate_random_ignition_points(
)
# print("Sampled ignition points:", ignition_points)
return ignition_points

def teardown(self):
"""Delete temporary input folder"""
if self.tmp_input_folder:
shutil.rmtree(self.tmp_input_folder)
print(f"Deleted {self.tmp_input_folder}")
25 changes: 19 additions & 6 deletions cell2fire/firehose/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,26 @@ class WillShenReward(Reward):
def name(cls) -> str:
return "WillShenReward"

def __call__(self, scale: float = 10, action: int = -1, **kwargs) -> float:
assert self.env.state.shape == self.env.forest_image.shape[:2]

def __call__(
self,
fire_scale: float = 10,
dist_scale: float = 0.05,
action: int = -1,
run_asserts: bool = False,
**kwargs
) -> float:
fire_idxs = np.array(np.where(self.env.state > 0)).T
num_cells_on_fire = fire_idxs.shape[0] if fire_idxs.size != 0 else 0

# Proportion of cells on fire
fire_term = 1 - num_cells_on_fire / self.env.num_cells
# -(num cells on fire) / (total num cells in forest) * scale
fire_term = -num_cells_on_fire / self.env.num_cells * fire_scale

# Some sanity checks
if run_asserts:
assert self.env.state.shape == self.env.forest_image.shape[:2]
assert fire_idxs.shape[1] == 2
fire_reward = FireSizeReward(self.env)
assert fire_term == fire_reward()

# Hack that will suffice for 2x2 and 3x3
if isinstance(action, list):
Expand All @@ -81,9 +93,10 @@ def __call__(self, scale: float = 10, action: int = -1, **kwargs) -> float:

# Penalize actions far away from fire
action_dist_term = min_dist_to_fire / self.env.max_dist
scaled_action_dist_term = action_dist_term * dist_scale

# % of cells on fire - min dist to fire / scale
reward = fire_term - action_dist_term / scale
reward = fire_term - scaled_action_dist_term
return reward


Expand Down
4 changes: 2 additions & 2 deletions cell2fire/firehose/video_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class FirehoseVideoRecorder:
"""Wrapper around gym VideoRecorder to record firehose environment."""

def __init__(self, env: FireEnv, algo: str, disable_video: bool = False):
def __init__(self, env: FireEnv, args, disable_video: bool = False):
if disable_video:
self.video_recorder = None
else:
Expand All @@ -17,7 +17,7 @@ def __init__(self, env: FireEnv, algo: str, disable_video: bool = False):
if not os.path.exists("videos"):
os.mkdir("videos")

video_fname = f"videos/{algo}-{date_str}.mp4"
video_fname = f"videos/{args.algo}-{args.map}-{date_str}.mp4"
self.video_recorder = VideoRecorder(env, video_fname, enabled=True)

def capture_frame(self):
Expand Down
25 changes: 19 additions & 6 deletions cell2fire/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
# Image of forest which we overlay
self.forest_image = self.helper.forest_image
self.uforest_image = (self.forest_image * 255).astype("uint8")
if(self.helper.reward_data is None):

if self.helper.reward_data is None:
self.reward_mask = None
else:
self.reward_mask = np.where(self.helper.reward_data > 0)
Expand Down Expand Up @@ -143,7 +143,9 @@ def __init__(
self.yx_to_flatten_idx: Dict[Tuple[int, int], int] = {
yx: idx for idx, yx, in self.flatten_idx_to_yx.items()
}
min_yx, max_yx = np.array(min(self.yx_to_flatten_idx)), np.array(max(self.yx_to_flatten_idx))
min_yx, max_yx = np.array(min(self.yx_to_flatten_idx)), np.array(
max(self.yx_to_flatten_idx)
)
self.max_dist = np.linalg.norm(max_yx - min_yx)

# Note: Reward function. Call this at end of __init__ just so we're safe
Expand Down Expand Up @@ -182,7 +184,10 @@ def _set_observation_space(self):
# Forest as a RGB image
# TODO: should we normalize the RGB image? At least divide by 255 so its [0, 1]
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width, 3), dtype=np.uint8,
low=0,
high=255,
shape=(self.height, self.width, 3),
dtype=np.uint8,
)
elif self.observation_type == "forest":
# Forest as -1 (harvested), 0 (nothing), 1 (on fire)
Expand All @@ -192,7 +197,10 @@ def _set_observation_space(self):
elif self.observation_space == "time":
# Blind model
self.observation_space = spaces.Box(
low=0, high=self.max_steps + 1, shape=(1,), dtype=np.uint8,
low=0,
high=self.max_steps + 1,
shape=(1,),
dtype=np.uint8,
)
else:
raise ValueError(f"Unsupported observation type {self.observation_type}")
Expand Down Expand Up @@ -282,7 +290,7 @@ def action_masks(self):
return mask

def _update_counters(self):
""" Update the counters based on the current state of the forest"""
"""Update the counters based on the current state of the forest"""
harvested = set(zip(*np.where(self.state == -1)))
on_fire = set(zip(*np.where(self.state == 1)))

Expand Down Expand Up @@ -462,6 +470,11 @@ def reset(self, **kwargs):

return self.get_observation()

def close(self):
"""Clean up after ourselves"""
self.helper.teardown()
super().close()


def main(debug: bool, delay_time: float = 0.0, **env_kwargs):
env = FireEnv(**env_kwargs, verbose=debug)
Expand Down
7 changes: 1 addition & 6 deletions cell2fire/rl_experiment_vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_model(self, env) -> Model:
print(
f"Overrode tensorboard log dir from {old_tf_logdir} to {self.tf_logdir}"
)
elif args.algo in {"ppo", "a2c", "trpo", "ppo-maskable"}:
elif args.algo in {"ppo", "a2c", "trpo", "ppo-maskable", "dqn"}:
# If no reload specified then just create a new model
model_cls = SB3_ALGO_TO_MODEL_CLASS[args.algo]
model = model_cls(
Expand All @@ -165,11 +165,6 @@ def _get_model(self, env) -> Model:
gamma=args.gamma,
policy_kwargs=self.model_kwargs,
)
elif args.algo == "dqn":
# DQN doesn't support gamma so handle separately
model = DQN(
args.architecture, env, verbose=1, tensorboard_log=self.tf_logdir
)
else:
raise NotImplementedError

Expand Down
Loading