Skip to content

Commit

Permalink
Merge pull request #269 from CITCOM-project/interaction-terms
Browse files Browse the repository at this point in the history
Temporary workaround for "I(...) not in df" bug
  • Loading branch information
jmafoster1 authored Mar 18, 2024
2 parents 5b9f113 + 4d78ae9 commit 0cb5492
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 19 deletions.
4 changes: 2 additions & 2 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ class CausalDAG(nx.DiGraph):
def __init__(self, dot_path: str = None, **attr):
super().__init__(**attr)
if dot_path:
with open(dot_path, 'r', encoding='utf-8') as file:
dot_content = file.read().replace('\n', '')
with open(dot_path, "r", encoding="utf-8") as file:
dot_content = file.read().replace("\n", "")
# Previously, we used pydot_graph_from_file() to read in the dot_path directly; however,
# this method does not currently have a way of removing spurious nodes.
# Workaround: Read in the file using open(), remove newlines, and then create the pydot_graph.
Expand Down
2 changes: 1 addition & 1 deletion causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SimulationResult:
relationship: str


class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
space to be searched"""

Expand Down
7 changes: 4 additions & 3 deletions causal_testing/surrogate/surrogate_search_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:

# pylint: disable=too-many-locals
def search(
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
) -> list:
solutions = []

Expand All @@ -47,7 +47,8 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
ate = surrogate.estimate_ate_calculated(adjustment_dict)
if len(ate) > 1:
raise ValueError(
"Multiple ate values provided but currently only single values supported in this method")
"Multiple ate values provided but currently only single values supported in this method"
)
return contradiction_function(ate[0])

gene_types, gene_space = self.create_gene_types(surrogate, specification)
Expand Down Expand Up @@ -84,7 +85,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument

@staticmethod
def create_gene_types(
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
) -> tuple[list, list]:
"""Generate the gene_types and gene_space for a given fitness function and specification
:param surrogate_model: Instance of a CubicSplineRegressionEstimator
Expand Down
28 changes: 16 additions & 12 deletions causal_testing/testing/causal_test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ class SomeEffect(CausalTestOutcome):
def apply(self, res: CausalTestResult) -> bool:
if res.test_value.type == "risk_ratio":
return any(
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
if res.test_value.type in ('coefficient', 'ate'):
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
)
if res.test_value.type in ("coefficient", "ate"):
return any(
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
)

raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

Expand All @@ -51,17 +53,19 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):

def apply(self, res: CausalTestResult) -> bool:
if res.test_value.type == "risk_ratio":
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
zip(res.ci_low(), res.ci_high(), res.test_value.value))
if res.test_value.type in ('coefficient', 'ate'):
return any(
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)
)
if res.test_value.type in ("coefficient", "ate"):
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
return (
sum(
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
)
/ len(value)
< self.ctol
sum(
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
)
/ len(value)
< self.ctol
)

raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
Expand Down
9 changes: 8 additions & 1 deletion causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,14 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
model = self._run_linear_regression()
newline = "\n"
patsy_md = ModelDesc.from_formula(self.treatment)
if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
if any(
(
self.df.dtypes[factor.name()] == "object"
for factor in patsy_md.rhs_termlist[1].factors
# In the long term, we want to remove this filter, as it prevents us from discovering categoricals within I(...) blocks
if factor.name() in self.df.dtypes
)
):
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
else:
Expand Down

0 comments on commit 0cb5492

Please sign in to comment.