Skip to content

Commit

Permalink
Merge pull request #12 from tobirohrer/feature/randomize-start-time-step
Browse files Browse the repository at this point in the history
Added `randomize_start_time_step` parameter
  • Loading branch information
tobirohrer authored Oct 12, 2023
2 parents fe7d2e9 + b591ecf commit 7a198a3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
16 changes: 13 additions & 3 deletions building_energy_storage_simulation/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,27 @@ class Environment(gym.Env):
:type num_forecasting_steps: int
:param building_simulation: Instance of `BuildingSimulation` to be wrapped as `gymnasium` environment.
:type building_simulation: BuildingSimulation
:param randomize_start_time_step: Randomizes the `start_index` in the `BuildingSimulation`. this should help prevent
the agent from overfitting to the data profile during training (otherwise it will always see the same time
series from the same start point)
:type randomize_start_time_step: bool
"""

def __init__(self,
building_simulation: BuildingSimulation,
max_timesteps: int = 2000,
num_forecasting_steps: int = 4,
randomize_start_time_step: bool = False
):

self.building_simulation = building_simulation
self.max_timesteps = max_timesteps
self.num_forecasting_steps = num_forecasting_steps
self.randomize_start_time_step = randomize_start_time_step
self.data_profile_length = len(self.building_simulation.solar_generation_profile)

assert self.max_timesteps + self.num_forecasting_steps < len(self.building_simulation.solar_generation_profile), \
"`max_timesteps` plus the forecast length cannot be greater than the length of the simulation profile."
assert self.max_timesteps + self.num_forecasting_steps <= self.data_profile_length, \
"`max_timesteps` plus the forecast length cannot be greater than the length of the data profiles."

self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
# Using np.inf as bounds as the observations must be rescaled externally anyways. E.g. Using the VecNormalize
Expand Down Expand Up @@ -61,6 +68,9 @@ def reset(self, seed=None, options=None) -> Tuple[ObsType, dict]:
"""

self.building_simulation.reset()
if self.randomize_start_time_step:
latest_possible_start_time_step = self.data_profile_length - self.max_timesteps - self.num_forecasting_steps
self.building_simulation.start_index = int(np.random.uniform(0, latest_possible_start_time_step))
return self.get_observation(), {}

def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
Expand Down Expand Up @@ -97,7 +107,7 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
'electricity_price': electricity_price}

def _get_terminated(self):
if self.building_simulation.step_count > self.max_timesteps:
if self.building_simulation.step_count >= self.max_timesteps:
return True
return False

Expand Down
38 changes: 35 additions & 3 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from building_energy_storage_simulation import Environment
import pytest
import numpy as np


@pytest.fixture(scope='module')
Expand All @@ -23,10 +24,20 @@ def test_environment_noop_step(building_simulation):
assert initial_obs[0][2] == obs[0][1]


def test_terminated_at_timelimit_reached(building_simulation):
env = Environment(building_simulation=building_simulation, num_forecasting_steps=0, max_timesteps=9)
@pytest.mark.parametrize(
"data_profile_length, num_forecasting_steps", [(2, 1), (9, 0)]
)
def test_terminated_at_timelimit_reached(data_profile_length, num_forecasting_steps):
dummy_profile = np.zeros(data_profile_length)
building_sim = BuildingSimulation(electricity_price=dummy_profile,
solar_generation_profile=dummy_profile,
electricity_load_profile=dummy_profile)
env = Environment(building_simulation=building_sim,
num_forecasting_steps=num_forecasting_steps,
max_timesteps=data_profile_length - num_forecasting_steps)
env.reset()
for i in range(10):
print(range(data_profile_length - num_forecasting_steps))
for i in range(data_profile_length - num_forecasting_steps):
obs, reward, terminated, trunc, info = env.step(0)
assert terminated is True

Expand Down Expand Up @@ -69,3 +80,24 @@ def test_default_initialization_runs_without_throwing():
env.reset()
env.step(1)
assert env.building_simulation.step_count == 1


def test_set_random_first_time_step_always_0_for_data_profile_length_2():
dummy_profile = [0, 0]
sim = BuildingSimulation(electricity_price=dummy_profile,
electricity_load_profile=dummy_profile,
solar_generation_profile=dummy_profile)
env = Environment(sim, randomize_start_time_step=True, max_timesteps=1, num_forecasting_steps=1)
env.reset()
assert env.building_simulation.start_index == 0


def test_set_random_first_time_step():
dummy_profile = np.zeros(1000)
sim = BuildingSimulation(electricity_price=dummy_profile,
electricity_load_profile=dummy_profile,
solar_generation_profile=dummy_profile)
env = Environment(sim, randomize_start_time_step=True, max_timesteps=1, num_forecasting_steps=1)
env.reset()
# This test is very unlikely to fail ;)
assert env.building_simulation.start_index != 0

0 comments on commit 7a198a3

Please sign in to comment.