Skip to content

Commit

Permalink
Updated docs, tests and examples
Browse files Browse the repository at this point in the history
  • Loading branch information
rsomers1998 committed Jun 23, 2023
1 parent 3d3d3b4 commit 45cb719
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 18 deletions.
3 changes: 1 addition & 2 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ various information. Here, we simply assert that the observed result is (on aver
causal_test_result = causal_test_engine.execute_test(
estimator = estimation_model,
causal_test_case = causal_test_case,
estimate_type = "ate")
causal_test_case = causal_test_case)
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
assert test_passes, "Expected to see a positive change in y."
Expand Down
6 changes: 3 additions & 3 deletions examples/covasim_/doubling_beta/example_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def doubling_beta_CATE_on_csv(
)

# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)

# Repeat for association estimate (no adjustment)
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
Expand All @@ -78,7 +78,7 @@ def doubling_beta_CATE_on_csv(
formula="cum_infections ~ beta + np.power(beta, 2)",
)
association_test_result = causal_test_engine.execute_test(
no_adjustment_linear_regression_estimator, causal_test_case, "ate"
no_adjustment_linear_regression_estimator, causal_test_case
)

# Store results for plotting
Expand Down Expand Up @@ -110,7 +110,7 @@ def doubling_beta_CATE_on_csv(
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
)
counterfactual_causal_test_result = causal_test_engine.execute_test(
linear_regression_estimator, causal_test_case, "ate"
linear_regression_estimator, causal_test_case
)
results_dict["counterfactual"] = {
"ate": counterfactual_causal_test_result.test_value.value,
Expand Down
2 changes: 1 addition & 1 deletion examples/covasim_/vaccinating_elderly/example_vaccine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
)

# 10. Execute test and save results in dict
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
if verbose:
logging.info("Causation:\n%s", causal_test_result)
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value
Expand Down
2 changes: 1 addition & 1 deletion examples/lr91/example_max_conductances.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
)

# 10. Run the causal test and print results
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
logger.info("%s", causal_test_result)
return causal_test_result.test_value.value, causal_test_result.confidence_intervals

Expand Down
2 changes: 1 addition & 1 deletion examples/poisson-line-process/example_poisson_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def causal_test_intensity_num_shapes(
)

# 10. Execute the test
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case, causal_test_case.estimate_type)
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case)

return causal_test_result

Expand Down
18 changes: 8 additions & 10 deletions tests/testing_tests/test_causal_test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
"A",
self.causal_test_engine.scenario_execution_data_df,
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="coefficient"
)
self.causal_test_case.estimate_type = "coefficient"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
self.assertEqual(int(causal_test_result.test_value.value), 0)

def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
Expand All @@ -205,9 +204,8 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
"A",
self.causal_test_engine.scenario_execution_data_df,
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="risk_ratio"
)
self.causal_test_case.estimate_type = "risk_ratio"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
self.assertEqual(int(causal_test_result.test_value.value), 0)

def test_invalid_estimate_type(self):
Expand All @@ -221,8 +219,9 @@ def test_invalid_estimate_type(self):
"A",
self.causal_test_engine.scenario_execution_data_df,
)
self.causal_test_case.estimate_type = "invalid"
with self.assertRaises(ValueError):
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case, estimate_type="invalid")
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)

def test_execute_test_observational_linear_regression_estimator_squared_term(self):
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
Expand Down Expand Up @@ -258,9 +257,8 @@ def test_execute_observational_causal_forest_estimator_cates(self):
self.causal_test_engine.scenario_execution_data_df,
effect_modifiers={"M": None},
)
causal_test_result = self.causal_test_engine.execute_test(
estimation_model, self.causal_test_case, estimate_type="cate"
)
self.causal_test_case.estimate_type = "cate"
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
causal_test_result = causal_test_result.test_value.value
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]
Expand Down

0 comments on commit 45cb719

Please sign in to comment.