
Commit 72c3997

Merge pull request #179 from CITCOM-project/JSON_treatment_var
JSON concrete tests
2 parents 106f346 + 2b919ed

7 files changed: +152 −80 lines changed


causal_testing/json_front/json_class.py

Lines changed: 71 additions & 34 deletions

@@ -20,6 +20,7 @@
 from causal_testing.specification.causal_specification import CausalSpecification
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Input, Meta, Output
+from causal_testing.testing.base_test_case import BaseTestCase
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_engine import CausalTestEngine
 from causal_testing.testing.estimators import Estimator
@@ -46,7 +47,7 @@ class JsonUtility:
 
     def __init__(self, output_path: str, output_overwrite: bool = False):
         self.input_paths = None
-        self.variables = None
+        self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
         self.data = []
         self.test_plan = None
         self.scenario = None
@@ -66,13 +67,71 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
     def setup(self, scenario: Scenario):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
         self.scenario = scenario
+        self._get_scenario_variables()
         self.scenario.setup_treatment_variables()
         self.causal_specification = CausalSpecification(
             scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
         )
         self._json_parse()
         self._populate_metas()
 
+    def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
+        """Runs and evaluates each test case specified in the JSON input
+
+        :param effects: Dictionary mapping effect class instances to string representations.
+        :param mutates: Dictionary mapping mutation functions to string representations.
+        :param estimators: Dictionary mapping estimator classes to string representations.
+        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
+        """
+        failures = 0
+        for test in self.test_plan["tests"]:
+            if "skip" in test and test["skip"]:
+                continue
+            test["estimator"] = estimators[test["estimator"]]
+            if "mutations" in test:
+                abstract_test = self._create_abstract_test_case(test, mutates, effects)
+
+                concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
+                failures = self._execute_tests(concrete_tests, test, f_flag)
+                msg = (
+                    f"Executing test: {test['name']}\n"
+                    + "abstract_test\n"
+                    + f"{abstract_test}\n"
+                    + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution}\n"
+                    + f"Number of concrete tests for test case: {str(len(concrete_tests))}\n"
+                    + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
+                )
+                self._append_to_file(msg, logging.INFO)
+            else:
+                outcome_variable = next(
+                    iter(test["expected_effect"])
+                )  # Take first key from dictionary of expected effect
+                base_test_case = BaseTestCase(
+                    treatment_variable=self.variables["inputs"][test["treatment_variable"]],
+                    outcome_variable=self.variables["outputs"][outcome_variable],
+                )
+
+                causal_test_case = CausalTestCase(
+                    base_test_case=base_test_case,
+                    expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
+                    control_value=test["control_value"],
+                    treatment_value=test["treatment_value"],
+                    estimate_type=test["estimate_type"],
+                )
+                if self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag):
+                    result = "failed"
+                else:
+                    result = "passed"
+
+                msg = (
+                    f"Executing concrete test: {test['name']} \n"
+                    + f"treatment variable: {test['treatment_variable']} \n"
+                    + f"outcome_variable = {outcome_variable} \n"
+                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
+                    + f"result - {result}"
                )
+                self._append_to_file(msg, logging.INFO)
+
     def _create_abstract_test_case(self, test, mutates, effects):
         assert len(test["mutations"]) == 1
         abstract_test = AbstractCausalTestCase(
@@ -81,7 +140,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
             treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
             expected_causal_effect={
                 self.scenario.variables[variable]: effects[effect]
-                for variable, effect in test["expectedEffect"].items()
+                for variable, effect in test["expected_effect"].items()
             },
             effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
             if "effect_modifiers" in test
@@ -91,35 +150,8 @@ def _create_abstract_test_case(self, test, mutates, effects):
         )
         return abstract_test
 
-    def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
-        """Runs and evaluates each test case specified in the JSON input
-
-        :param effects: Dictionary mapping effect class instances to string representations.
-        :param mutates: Dictionary mapping mutation functions to string representations.
-        :param estimators: Dictionary mapping estimator classes to string representations.
-        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
-        """
+    def _execute_tests(self, concrete_tests, test, f_flag):
         failures = 0
-        for test in self.test_plan["tests"]:
-            if "skip" in test and test["skip"]:
-                continue
-            abstract_test = self._create_abstract_test_case(test, mutates, effects)
-
-            concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-            failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-            msg = (
-                f"Executing test: {test['name']} \n"
-                + "abstract_test \n"
-                + f"{abstract_test} \n"
-                + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n"
-                + f"Number of concrete tests for test case: {str(len(concrete_tests))} \n"
-                + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
-            )
-            self._append_to_file(msg, logging.INFO)
-
-    def _execute_tests(self, concrete_tests, estimators, test, f_flag):
-        failures = 0
-        test["estimator"] = estimators[test["estimator"]]
         if "formula" in test:
             self._append_to_file(f"Estimator formula used for test: {test['formula']}")
         for concrete_test in concrete_tests:
@@ -161,15 +193,13 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
         :rtype: bool
         """
         failed = False
-
         causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
         )
 
         test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
 
-        result_string = str()
         if causal_test_result.ci_low() and causal_test_result.ci_high():
             result_string = (
                 f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
@@ -214,7 +244,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
         }
         if "formula" in test:
             estimator_kwargs["formula"] = test["formula"]
-
         estimation_model = test["estimator"](**estimator_kwargs)
         return causal_test_engine, estimation_model
 
@@ -226,10 +255,18 @@ def _append_to_file(self, line: str, log_level: int = None):
         is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
-            f.write(line)
+            f.write(line + "\n")
         if log_level:
             logger.log(level=log_level, msg=line)
 
+    def _get_scenario_variables(self):
+        for input_var in self.scenario.inputs():
+            self.variables["inputs"][input_var.name] = input_var
+        for output_var in self.scenario.outputs():
+            self.variables["outputs"][output_var.name] = output_var
+        for meta_var in self.scenario.metas():
+            self.variables["metas"][meta_var.name] = meta_var
+
     @staticmethod
     def check_file_exists(output_path: Path, overwrite: bool):
         """Method that checks if the given path to an output file already exists. If overwrite is true the check is

docs/source/frontends/json_front_end.rst

Lines changed: 20 additions & 9 deletions

@@ -21,17 +21,28 @@ Use case specific information is also declared here such as the paths to the rel
 
 causal_tests.json
 -----------------
-`examples/poisson/causal_tests.json <https://github.com/CITCOM-project/CausalTestingFramework/blob/main/examples/poisson/causal_tests.json>`_ contains python code written by the user to implement scenario specific features
-is the JSON file that allows for the easy specification of multiple causal tests.
+`examples/poisson/causal_tests.json <https://github.com/CITCOM-project/CausalTestingFramework/blob/main/examples/poisson/causal_tests.json>`_ is the JSON file that allows for the easy specification of multiple causal tests. Tests can be specified in two ways: first, by specifying a mutation, as in the example tests, with the following structure:
 Each test requires:
-1. Test name
-2. Mutations
-3. Estimator
-4. Estimate_type
-5. Effect modifiers
-6. Expected effects
-7. Skip: boolean that if set true the test won't be executed and will be skipped
 
+#. name
+#. mutations
+#. estimator
+#. estimate_type
+#. effect_modifiers
+#. expected_effects
+#. skip: boolean that if set true the test won't be executed and will be skipped
+
+The second method of specifying a test is to specify the test in a concrete form with the following structure:
+
+#. name
+#. treatment_variable
+#. control_value
+#. treatment_value
+#. estimator
+#. estimate_type
+#. expected_effect
+#. skip
 
 Run Commands
 ------------
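For reference, a concrete test of the second form could look like the entry below, shown as a Python dict with the same keys the JSON object would carry. All names and values are illustrative rather than taken from the example suite; as the new `run_json_tests` shows, the first key of `expected_effect` is taken as the outcome variable, and the `estimator` and effect strings are resolved through the mappings passed in by the driver.

```python
# Hypothetical entry from the "tests" list in causal_tests.json (values illustrative).
concrete_test = {
    "name": "width_increases_num_shapes",
    "treatment_variable": "width",                  # looked up in scenario inputs
    "control_value": 1.0,
    "treatment_value": 2.0,
    "estimator": "LinearRegressionEstimator",       # resolved via the estimators mapping
    "estimate_type": "ate",                         # assumed estimate type
    "expected_effect": {"num_shapes": "Positive"},  # first key is taken as the outcome
    "skip": False,
}
```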

examples/poisson/README.md

Lines changed: 1 addition & 1 deletion

@@ -6,6 +6,6 @@ To run this case study:
 1. Ensure all project dependencies are installed by running `pip install .` in the top level directory
    (instructions are provided in the project README).
 2. Change directory to `causal_testing/examples/poisson`.
-3. Run the command `python test_run_causal_tests.py --data_path data.csv --dag_path dag.dot --json_path causal_tests.json`
+3. Run the command `python example_run_causal_tests.py --data_path data.csv --dag_path dag.dot --json_path causal_tests.json`
 
 This should print a series of causal test results and produce two CSV files. `intensity_num_shapes_results_random_1000.csv` corresponds to table 1, and `width_num_shapes_results_random_1000.csv` relates to our findings regarding the relationship of width and `P_u`.
