Skip to content

Commit

Permalink
Low visiblity survey and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Sep 20, 2023
1 parent 83bb778 commit 0846e87
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 23 deletions.
57 changes: 40 additions & 17 deletions telescope_positioning_simulation/Survey/cummulative_survey.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from telescope_positioning_simulation.Survey.survey import Survey
import numpy as np
import pandas as pd
import json


class CummulativeSurvey(Survey):
Expand All @@ -20,7 +21,11 @@ def step(self, action: dict):
observation_pd = {key: observation[key].ravel() for key in observation.keys()}

observation_pd = pd.DataFrame(observation_pd)
observation_pd["action"] = str(action["location"]) + str(action["band"])
observation_pd["action"] = str(
{"location": action["location"], "band": self.observator.band}
)
observation_pd["band"] = self.observator.band
observation_pd["location"] = str(action["location"])
observation_pd["reward"] = reward

self.all_steps = self.all_steps.append(observation_pd)
Expand All @@ -39,18 +44,6 @@ class UniformSurvey(CummulativeSurvey):
survey_config (dict): Parameters for the survey, including the stopping conditions, the validity conditions, the variables to collect, as read by IO.ReadConfig
threshold (float): Threshold the survey must pass to have its quality counted towards the total reward
uniform (str): ["site", "quality"] - If measuring the uniformity of the number of times each site has been visited, or the uniformity of the quality of observations
Examples:
>>> survey = Survey(observatory_config, survey_config)
action_generator = ActionGenerator() # Attributary function to produce time, location pairs
for step in range(10):
action_time, action_location = action_generator()
update_action = {"time":[action_time], "location":{"ra":[action_location["ra"]], "decl":[action_location["decl"]]}}
observation, reward, stop, log = survey.step(update_action)
>>> survey = Survey(observatory_config, survey_config)
# Run without changing the location, only stepping time forward
survey_results = survey()
"""

def __init__(
Expand Down Expand Up @@ -81,7 +74,9 @@ def site_reward(self):
self.all_steps["action"].isin(counts.index[counts < self.threshold]),
"reward",
] = 0
reward_sum = current_steps.groupby(["mjd", "action"])["reward"].sum().sum()
reward_sum = (
current_steps.groupby(["mjd", "action", "band"])["reward"].sum().sum()
)

return reward_scale * reward_sum

Expand All @@ -106,22 +101,50 @@ def _subclass_reward(self, *args, **kwargs):

class LowVisiblitySurvey(CummulativeSurvey):
def __init__(
self, observatory_config: dict, survey_config: dict, required_sites: dict = {}
self,
observatory_config: dict,
survey_config: dict,
required_sites: list = [],
other_site_weight: float = 0.6,
time_tolerance: float = 0.01388,
) -> None:
super().__init__(observatory_config, survey_config)

self.all_steps = pd.DataFrame()
self.required_sites = required_sites
self.time_tolerance = time_tolerance
self.weight = other_site_weight

def sites_hit(self):
# TODO do this with sets and arrays instead of a loop
hit_counter = 0
for site in self.required_sites:

subset = self.all_steps.copy()
if "time" in site.keys():
subset = subset[
(subset["mjd"] < site["time"][0] + self.time_tolerance)
& (subset["mjd"] > site["time"][0] - self.time_tolerance)
]

if "band" in site.keys():
subset = subset[subset["band"] == site["band"]]

subset = subset[subset["location"] == str(site["location"])]
hit_counter += len(subset)

return hit_counter

def _subclass_reward(self):
if len(self.all_steps) != 0:

reward_scale = 1 / len(self.all_steps)
weighted_term = self.weight * self.all_steps["reward"].sum()
number_of_interest_hit = ""
number_of_interest_hit = self.sites_hit()

reward = reward_scale * (weighted_term + number_of_interest_hit)
reward = reward if not (pd.isnull(reward) or reward == -np.inf) else 0

return reward

else:
return 0
155 changes: 149 additions & 6 deletions test/test_cummaltive_survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,180 @@
def test_uniform_site():
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
obs_config["start_time"] = 59946
uniform_survey = UniformSurvey(
observatory_config=obs_config, survey_config=ReadConfig(survey=True)()
)
assert len(uniform_survey.all_steps) == 0
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0

uniform_survey.step(action)
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0
assert len(uniform_survey.all_steps) == 1

uniform_survey.step(action_2)
assert len(uniform_survey.all_steps) == 2
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0


def test_uniform_quality():

obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
obs_config["start_time"] = 59946

uniform_survey = UniformSurvey(
observatory_config=obs_config,
survey_config=ReadConfig(survey=True)(),
uniform="quality",
)
assert len(uniform_survey.all_steps) == 0
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0

uniform_survey.step(action)
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0
assert len(uniform_survey.all_steps) == 1

uniform_survey.step(action_2)
assert len(uniform_survey.all_steps) == 2
assert uniform_survey._reward() == 0
assert uniform_survey._subclass_reward() == 0


# Todo - schedule that passes the conditions in the defaults


def test_lowvis_hit_required_sites():
required_sites = [
{"location": {"ra": [ra], "decl": [decl]}} for ra, decl in zip([0, 10], [0, 10])
]
expected_reward = 1
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
survey_config = ReadConfig(survey=True)()
survey_config["invalid_penality"] = 0

survey = LowVisiblitySurvey(
observatory_config=obs_config,
survey_config=survey_config,
required_sites=required_sites,
other_site_weight=0,
)
actions = [
{"location": {"ra": [0], "decl": [0]}, "band": "g"},
{"location": {"ra": [10], "decl": [10]}, "band": "g"},
]
for action in actions:
survey.step(action)

print(survey.all_steps)
assert survey._subclass_reward() >= expected_reward


def test_lowvis_hit_required_sites_correct_time():
required_sites = [
{"time": [time], "location": {"ra": [ra], "decl": [decl]}}
for ra, decl, time in zip([0, 10], [0, 10], [59946.08, 59946.1])
]

expected_reward = 1
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
survey_config = ReadConfig(survey=True)()
survey_config["invalid_penality"] = 0

survey = LowVisiblitySurvey(
observatory_config=obs_config,
survey_config=survey_config,
required_sites=required_sites,
other_site_weight=0,
)
for action in required_sites:
survey.step(action)

assert survey._subclass_reward() >= expected_reward


def test_lowvis_hit_required_sites_incorrect_time():
required_sites = [
{"time": [time], "location": {"ra": [ra], "decl": [decl]}}
for ra, decl, time in zip([0, 10], [0, 10], [59946.08, 59946.1])
]

expected_reward = 1
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
survey_config = ReadConfig(survey=True)()
survey_config["invalid_penality"] = 0

survey = LowVisiblitySurvey(
observatory_config=obs_config,
survey_config=survey_config,
required_sites=required_sites,
other_site_weight=0,
)
actions = [
{"time": [time], "location": {"ra": [ra], "decl": [decl]}}
for ra, decl, time in zip([0, 10], [0, 10], [59948, 59948.1])
]
for action in actions:
survey.step(action)

assert survey._subclass_reward() < expected_reward


def test_lowvis_hit_required_sites_incorrect_band():
required_sites = [
{"band": band, "location": {"ra": [ra], "decl": [decl]}}
for ra, decl, band in zip([0, 10], [0, 10], ["g", "g"])
]

expected_reward = 1
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}
survey_config = ReadConfig(survey=True)()
survey_config["invalid_penality"] = 0

survey = LowVisiblitySurvey(
observatory_config=obs_config,
survey_config=survey_config,
required_sites=required_sites,
other_site_weight=0,
)
actions = [
{"band": band, "location": {"ra": [ra], "decl": [decl]}}
for ra, decl, band in zip([0, 10], [0, 10], ["b", "v"])
]

for action in actions:
survey.step(action)

assert survey._subclass_reward() < expected_reward


def test_lowvis_incorrect_sites():
required_sites = [
{"location": {"ra": [ra], "decl": [decl]}}
for ra, decl in zip([15, 20], [10, 20])
]

expected_reward = 1
obs_config = ReadConfig(survey=False)()
obs_config["location"] = {"ra": [0], "decl": [0]}

survey_config = ReadConfig(survey=True)()
survey_config["invalid_penality"] = 0

survey = LowVisiblitySurvey(
observatory_config=obs_config,
survey_config=survey_config,
required_sites=required_sites,
other_site_weight=0,
)
actions = [
{"location": {"ra": [ra], "decl": [decl]}} for ra, decl in zip([0, 10], [0, 10])
]

for action in actions:
survey.step(action)

assert survey._subclass_reward() < expected_reward

0 comments on commit 0846e87

Please sign in to comment.