Skip to content

Commit

Permalink
Merge branch 'main' into json_optional_data_path
Browse files Browse the repository at this point in the history
  • Loading branch information
christopher-wild authored Jun 26, 2023
2 parents be0c7ad + aa90884 commit e390c3e
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 36 deletions.
4 changes: 1 addition & 3 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ def _execute_test_case(
causal_test_engine, estimation_model = self._setup_test(
causal_test_case, test, test["conditions"] if "conditions" in test else None
)
causal_test_result = causal_test_engine.execute_test(
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
)
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)

test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)

Expand Down
27 changes: 12 additions & 15 deletions causal_testing/testing/causal_test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu

estimators = test_suite[edge]["estimators"]
tests = test_suite[edge]["tests"]
estimate_type = test_suite[edge]["estimate_type"]
results = {}
for estimator_class in estimators:
causal_test_results = []
Expand All @@ -96,16 +95,14 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
)
if estimator.df is None:
estimator.df = self.scenario_execution_data_df
causal_test_result = self._return_causal_test_results(estimate_type, estimator, test)
causal_test_result = self._return_causal_test_results(estimator, test)
causal_test_results.append(causal_test_result)

results[estimator_class.__name__] = causal_test_results
test_suite_results[edge] = results
return test_suite_results

def execute_test(
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
) -> CausalTestResult:
def execute_test(self, estimator: type(Estimator), causal_test_case: CausalTestCase) -> CausalTestResult:
"""Execute a causal test case and return the causal test result.
Test case execution proceeds with the following steps:
Expand All @@ -120,7 +117,6 @@ def execute_test(
:param estimator: A reference to an Estimator class.
:param causal_test_case: The CausalTestCase object to be tested
:param estimate_type: A string which denotes the type of estimate to return, ATE or CATE.
:return causal_test_result: A CausalTestResult for the executed causal test case.
"""
if self.scenario_execution_data_df.empty:
Expand All @@ -142,18 +138,17 @@ def execute_test(
if self._check_positivity_violation(variables_for_positivity):
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")

causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
causal_test_result = self._return_causal_test_results(estimator, causal_test_case)
return causal_test_result

def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
def _return_causal_test_results(self, estimator, causal_test_case):
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
:param estimate_type: A string which denotes the type of estimate to return
:param estimator: An Estimator class object
:param causal_test_case: The concrete test case to be executed
:return: a CausalTestResult object containing the confidence intervals
"""
if estimate_type == "cate":
if causal_test_case.estimate_type == "cate":
logger.debug("calculating cate")
if not hasattr(estimator, "estimate_cates"):
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
Expand All @@ -165,7 +160,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
)
elif estimate_type == "risk_ratio":
elif causal_test_case.estimate_type == "risk_ratio":
logger.debug("calculating risk_ratio")
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
causal_test_result = CausalTestResult(
Expand All @@ -174,7 +169,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
)
elif estimate_type == "coefficient":
elif causal_test_case.estimate_type == "coefficient":
logger.debug("calculating coefficient")
coefficient, confidence_intervals = estimator.estimate_unit_ate()
causal_test_result = CausalTestResult(
Expand All @@ -183,7 +178,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
)
elif estimate_type == "ate":
elif causal_test_case.estimate_type == "ate":
logger.debug("calculating ate")
ate, confidence_intervals = estimator.estimate_ate()
causal_test_result = CausalTestResult(
Expand All @@ -194,7 +189,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
)
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
elif estimate_type == "ate_calculated":
elif causal_test_case.estimate_type == "ate_calculated":
logger.debug("calculating ate")
ate, confidence_intervals = estimator.estimate_ate_calculated()
causal_test_result = CausalTestResult(
Expand All @@ -206,7 +201,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
else:
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
raise ValueError(
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
)
return causal_test_result

def _check_positivity_violation(self, variables_list):
Expand Down
3 changes: 1 addition & 2 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ various information. Here, we simply assert that the observed result is (on aver
causal_test_result = causal_test_engine.execute_test(
estimator = estimation_model,
causal_test_case = causal_test_case,
estimate_type = "ate")
causal_test_case = causal_test_case)
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
assert test_passes, "Expected to see a positive change in y."
Expand Down
6 changes: 3 additions & 3 deletions examples/covasim_/doubling_beta/example_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def doubling_beta_CATE_on_csv(
)

# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)

# Repeat for association estimate (no adjustment)
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
Expand All @@ -78,7 +78,7 @@ def doubling_beta_CATE_on_csv(
formula="cum_infections ~ beta + np.power(beta, 2)",
)
association_test_result = causal_test_engine.execute_test(
no_adjustment_linear_regression_estimator, causal_test_case, "ate"
no_adjustment_linear_regression_estimator, causal_test_case
)

# Store results for plotting
Expand Down Expand Up @@ -110,7 +110,7 @@ def doubling_beta_CATE_on_csv(
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
)
counterfactual_causal_test_result = causal_test_engine.execute_test(
linear_regression_estimator, causal_test_case, "ate"
linear_regression_estimator, causal_test_case
)
results_dict["counterfactual"] = {
"ate": counterfactual_causal_test_result.test_value.value,
Expand Down
2 changes: 1 addition & 1 deletion examples/covasim_/vaccinating_elderly/example_vaccine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
)

# 10. Execute test and save results in dict
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
if verbose:
logging.info("Causation:\n%s", causal_test_result)
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value
Expand Down
2 changes: 1 addition & 1 deletion examples/lr91/example_max_conductances.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
)

# 10. Run the causal test and print results
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
logger.info("%s", causal_test_result)
return causal_test_result.test_value.value, causal_test_result.confidence_intervals

Expand Down
2 changes: 1 addition & 1 deletion examples/poisson-line-process/example_poisson_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def causal_test_intensity_num_shapes(
)

# 10. Execute the test
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case, causal_test_case.estimate_type)
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case)

return causal_test_result

Expand Down
18 changes: 8 additions & 10 deletions tests/testing_tests/test_causal_test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
"A",
self.causal_test_engine.scenario_execution_data_df,
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="coefficient"
)
self.causal_test_case.estimate_type = "coefficient"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
self.assertEqual(int(causal_test_result.test_value.value), 0)

def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
Expand All @@ -205,9 +204,8 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
"A",
self.causal_test_engine.scenario_execution_data_df,
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="risk_ratio"
)
self.causal_test_case.estimate_type = "risk_ratio"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
self.assertEqual(int(causal_test_result.test_value.value), 0)

def test_invalid_estimate_type(self):
Expand All @@ -221,8 +219,9 @@ def test_invalid_estimate_type(self):
"A",
self.causal_test_engine.scenario_execution_data_df,
)
self.causal_test_case.estimate_type = "invalid"
with self.assertRaises(ValueError):
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case, estimate_type="invalid")
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)

def test_execute_test_observational_linear_regression_estimator_squared_term(self):
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
Expand Down Expand Up @@ -258,9 +257,8 @@ def test_execute_observational_causal_forest_estimator_cates(self):
self.causal_test_engine.scenario_execution_data_df,
effect_modifiers={"M": None},
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="cate"
)
self.causal_test_case.estimate_type = "cate"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
causal_test_result = causal_test_result.test_value.value
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]
Expand Down

0 comments on commit e390c3e

Please sign in to comment.