Skip to content

Commit

Permalink
Merge pull request #286 from CITCOM-project/temporal-shenanigans
Browse files Browse the repository at this point in the history
Enabling causal effect estimation in the presence of time-varying confounding using IPCW.
  • Loading branch information
jmafoster1 authored Aug 2, 2024
2 parents b6ce637 + e6284c5 commit 2b7042d
Show file tree
Hide file tree
Showing 15 changed files with 593 additions and 27 deletions.
2 changes: 1 addition & 1 deletion causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _execute_test_case(
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)

if "coverage" in test and test["coverage"]:
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, self.data_collector)
adequacy_metric = DataAdequacy(causal_test_case, estimation_model)
adequacy_metric.measure_adequacy()
causal_test_result.adequacy = adequacy_metric

Expand Down
76 changes: 76 additions & 0 deletions causal_testing/specification/capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
This module contains the Capability and TreatmentSequence classes to implement
treatment sequences that operate over time.
"""
from typing import Any
from causal_testing.specification.variable import Variable


class Capability:
    """
    A single temporal intervention: a variable, the value it is given, and the
    start and end times between which the intervention applies.
    """

    def __init__(self, variable: Variable, value: Any, start_time: int, end_time: int):
        """
        :param variable: The variable the intervention acts on.
        :param value: The value assigned to the variable.
        :param start_time: The time at which the intervention starts.
        :param end_time: The time at which the intervention ends.
        """
        self.variable, self.value = variable, value
        self.start_time, self.end_time = start_time, end_time

    def __eq__(self, other):
        """
        Two capabilities are equal when `other` is an instance of this object's
        type and all four fields compare equal.

        NOTE: defining `__eq__` without `__hash__` leaves instances unhashable.
        """
        if not isinstance(other, type(self)):
            return False
        return (
            self.variable == other.variable
            and self.value == other.value
            and self.start_time == other.start_time
            and self.end_time == other.end_time
        )

    def __repr__(self):
        """
        Render the capability as "(variable, value, start-end)".
        """
        window = f"{self.start_time}-{self.end_time}"
        return f"({self.variable}, {self.value}, {window})"


class TreatmentSequence:
    """
    Class to represent a list of capabilities, i.e. a treatment regime.

    Counting from 1, the i-th (variable, value) pair becomes a Capability that
    starts at `i * timesteps_per_intervention` and ends one
    `timesteps_per_intervention` later.
    """

    def __init__(self, timesteps_per_intervention, capabilities):
        """
        :param timesteps_per_intervention: The (positive) number of timesteps each
                                           intervention lasts.
        :param capabilities: An iterable of (variable, value) pairs in the order
                             they are to be applied.
        :raises ValueError: If timesteps_per_intervention is not positive.
        """
        # A negative step used to make the range-based schedule shorter than the
        # input, silently dropping capabilities; zero failed with an opaque
        # "range() arg 3 must not be zero". Reject both explicitly.
        if timesteps_per_intervention <= 0:
            raise ValueError(
                f"timesteps_per_intervention must be positive, got {timesteps_per_intervention}"
            )
        self.timesteps_per_intervention = timesteps_per_intervention
        self.capabilities = [
            Capability(
                var,
                val,
                step * timesteps_per_intervention,
                step * timesteps_per_intervention + timesteps_per_intervention,
            )
            for step, (var, val) in enumerate(capabilities, start=1)
        ]
        # This is a bodge so that causal test adequacy works
        self.name = tuple(c.variable for c in self.capabilities)

    def set_value(self, index: int, value: Any):
        """
        Set the value of capability at the given index.
        :param index - the index of the element to update.
        :param value - the desired value of the capability.
        """
        self.capabilities[index].value = value

    def copy(self):
        """
        Return a new TreatmentSequence rebuilt from this one's
        timesteps_per_intervention and its current (variable, value) pairs.

        The Capability objects are new, so calling `set_value` on the copy does
        not affect the original. The variable and value objects themselves are
        shared rather than deep-copied, and start/end times are recomputed from
        the position of each capability in the list.
        """
        return TreatmentSequence(
            self.timesteps_per_intervention,
            [(c.variable, c.value) for c in self.capabilities],
        )

    def total_time(self):
        """
        Calculate the total duration of the treatment strategy.

        The first capability starts one interval in, so this equals the end time
        of the last capability: (number of capabilities + 1) intervals.
        """
        return (len(self.capabilities) + 1) * self.timesteps_per_intervention
49 changes: 39 additions & 10 deletions causal_testing/testing/causal_test_adequacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
This module contains code to measure various aspects of causal test adequacy.
"""

import logging
from itertools import combinations
from copy import deepcopy
import pandas as pd
from numpy.linalg import LinAlgError
from lifelines.exceptions import ConvergenceError

from causal_testing.testing.causal_test_suite import CausalTestSuite
from causal_testing.data_collection.data_collector import DataCollector
from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.testing.estimators import Estimator
from causal_testing.testing.causal_test_case import CausalTestCase

logger = logging.getLogger(__name__)


class DAGAdequacy:
"""
Expand Down Expand Up @@ -70,15 +74,21 @@ class DataAdequacy:
- Zero kurtosis is optimal.
"""

# pylint: disable=too-many-instance-attributes
def __init__(
self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
self,
test_case: CausalTestCase,
estimator: Estimator,
bootstrap_size: int = 100,
group_by=None,
):
self.test_case = test_case
self.estimator = estimator
self.data_collector = data_collector
self.kurtosis = None
self.outcomes = None
self.successful = None
self.bootstrap_size = bootstrap_size
self.group_by = group_by

def measure_adequacy(self):
"""
Expand All @@ -87,11 +97,24 @@ def measure_adequacy(self):
results = []
for i in range(self.bootstrap_size):
estimator = deepcopy(self.estimator)
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
# try:
results.append(self.test_case.execute_test(estimator, self.data_collector))
# except np.LinAlgError:
# continue

if self.group_by is not None:
ids = pd.Series(estimator.df[self.group_by].unique())
ids = ids.sample(len(ids), replace=True, random_state=i)
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)]
else:
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
try:
results.append(self.test_case.execute_test(estimator, None))
except LinAlgError:
logger.warning("Adequacy LinAlgError")
continue
except ConvergenceError:
logger.warning("Adequacy ConvergenceError")
continue
except ValueError as e:
logger.warning(f"Adequacy ValueError: {e}")
continue
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]

Expand All @@ -111,8 +134,14 @@ def convert_to_df(field):

effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True)
self.kurtosis = effect_estimate.kurtosis()
self.outcomes = sum(outcomes)
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
self.successful = sum(x is not None for x in outcomes)

def to_dict(self):
"Returns the adequacy object as a dictionary."
return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}
return {
"kurtosis": self.kurtosis.to_dict(),
"bootstrap_size": self.bootstrap_size,
"passing": self.outcomes,
"successful": self.successful,
}
2 changes: 1 addition & 1 deletion causal_testing/testing/causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
except np.linalg.LinAlgError:
return CausalTestResult(
estimator=estimator,
test_value=TestValue(self.estimate_type, "LinAlgError"),
test_value=TestValue(self.estimate_type, None),
effect_modifier_configuration=self.effect_modifier_configuration,
confidence_intervals=None,
)
Expand Down
6 changes: 4 additions & 2 deletions causal_testing/testing/causal_test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class SomeEffect(CausalTestOutcome):
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""

def apply(self, res: CausalTestResult) -> bool:
if res.test_value.type == "risk_ratio":
if res.ci_low() is None or res.ci_high() is None:
return None
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
return any(
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
)
Expand All @@ -52,7 +54,7 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
self.ctol = ctol

def apply(self, res: CausalTestResult) -> bool:
if res.test_value.type == "risk_ratio":
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
return any(
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)
Expand Down
Loading

0 comments on commit 2b7042d

Please sign in to comment.