Skip to content

Commit

Permalink
updated aps case study for multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
rsomers1998 committed Dec 6, 2023
1 parent 7c5e84b commit 727fb6d
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions examples/surrogate_assisted/apsdigitaltwin/apsdigitaltwin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.causal_surrogate_assisted import CausalSurrogateAssistedTestCase, SimulationResult, Simulator
from causal_testing.testing.surrogate_search_algorithms import GeneticSearchAlgorithm
from examples.apsdigitaltwin.util.model import Model, OpenAPS, i_label, g_label, s_label
from examples.surrogate_assisted.apsdigitaltwin.util.model import Model, OpenAPS, i_label, g_label, s_label

import pandas as pd
import numpy as np
import os
import multiprocessing as mp

import random
from dotenv import load_dotenv


class APSDigitalTwinSimulator(Simulator):
def __init__(self, constants, profile_path) -> None:
def __init__(self, constants, profile_path, output_file = "./openaps_temp") -> None:
super().__init__()

self.constants = constants
self.profile_path = profile_path
self.output_file = output_file

def run_with_config(self, configuration) -> SimulationResult:
min_bg = 200
Expand All @@ -33,7 +37,7 @@ def run_with_config(self, configuration) -> SimulationResult:
model_openaps = Model([configuration["start_cob"], 0, 0, configuration["start_bg"], configuration["start_iob"]], self.constants)
for t in range(1, 121):
if t % 5 == 1:
rate = open_aps.run(model_openaps.history, output_file=f"./openaps_temp", faulty=True)
rate = open_aps.run(model_openaps.history, output_file=self.output_file, faulty=True)
if rate == -1:
violation = True
open_aps_output += rate
Expand Down Expand Up @@ -62,8 +66,9 @@ def run_with_config(self, configuration) -> SimulationResult:

return SimulationResult(data, violation)

if __name__ == "__main__":
load_dotenv()
def main(file):
random.seed(123)
np.random.seed(123)

search_bias = Input("search_bias", float, hidden=True)

Expand Down Expand Up @@ -106,13 +111,38 @@ def run_with_config(self, configuration) -> SimulationResult:
ga_search = GeneticSearchAlgorithm(config=ga_config)

constants = []
with open("constants.txt", "r") as const_file:
const_file_name = file.replace("datasets", "constants").replace("_np_random_random_faulty_scenarios", ".txt")
with open(const_file_name, "r") as const_file:
constants = const_file.read().replace("[", "").replace("]", "").split(",")
constants = [np.float64(const) for const in constants]
constants[7] = int(constants[7])

simulator = APSDigitalTwinSimulator(constants, "./util/profile.json")
data_collector = ObservationalDataCollector(scenario, pd.read_csv("./data.csv"))
simulator = APSDigitalTwinSimulator(constants, "./util/profile.json", f"./{file}_openaps_temp")
data_collector = ObservationalDataCollector(scenario, pd.read_csv(f"./{file}.csv"))
test_case = CausalSurrogateAssistedTestCase(specification, ga_search, simulator)

print(test_case.execute(data_collector))
res, iter, df = test_case.execute(data_collector)
with open(f"./outputs/{file.replace('./datasets/', '')}.txt", "w") as out:
out.write(str(res) + "\n" + str(iter))
df.to_csv(f"./outputs/{file.replace('./datasets/', '')}_full.csv")

print(f"finished {file}")

if __name__ == "__main__":
load_dotenv()

all_traces = os.listdir("./datasets")
while len(all_traces) > 0:
num = 1
if num > len(all_traces):
num = len(all_traces)

with mp.Pool(processes=num) as pool:
pool_vals = []
while len(pool_vals) < num and len(all_traces) > 0:
data_trace = all_traces.pop()
if data_trace.endswith(".csv"):
if len(pd.read_csv(os.path.join("./datasets", data_trace))) >= 300:
pool_vals.append(f"./datasets/{data_trace[:-4]}")

pool.map(main, pool_vals)

0 comments on commit 727fb6d

Please sign in to comment.