API proposal for v1
DoWhy's current API was designed for the effect estimation task. As the library expands to more tasks, we revisit some of the design decisions and propose an updated API. In particular, the most significant change is moving from an object-oriented API to a functional API. Beyond this change, our goal is to retain the current effect estimation API with exactly the same input-output signature while adding new API functions to include other tasks.
As you can see below, we envision two ways of achieving each task: 1) the task-specific API, and 2) using a Graphical Causal Model (GCM). Having access to a fitted GCM simplifies the computation of almost every task, but it requires knowledge of the full causal graph. Hence, if the full graph is known, we suggest using the GCM API; otherwise, the task-specific API can be used.
We welcome your contributions and feedback; you can join the discussion page here.
DoWhy started out by using a CausalModel class for the effect estimation task, and the current API provides an easy four-step process.
m = CausalModel(data, graph, ..)
identified_estimand = m.identify_effect()
estimate = m.estimate_effect(identified_estimand,
                             method="dowhy.propensity_score_matching")
refute = m.refute_estimate(identified_estimand, estimate,
                           method="placebo_treatment_refuter")
However, different tasks may need to use the CausalModel differently. Further, object-oriented methods do not make all of their required inputs explicit (e.g., the identification method does not require access to data, yet it is called on a model constructed with data).
To make the function arguments explicit and avoid book-keeping inside the CausalModel class for different tasks, we propose switching to a functional approach. The idea is that the same end-user code will work in the new API, often just by replacing CausalModel.method
with dowhy.method([data,graph])
where one of the two parameters may be optional. Note that individual identifiers, estimators, and refuters still remain as classes. It is only the user-facing functions that are being changed to a functional API.
In addition, we would like to enable MyPy types for the DoWhy API.
So there are two goals: 1) move to a functional API; and 2) add mypy types to the API.
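As an illustration of both goals, a user-facing function could be a typed, stateless wrapper; all names and signatures below are hypothetical sketches, not the final API.

```python
# Sketch only: a typed, functional user-facing API. Class and argument
# names are hypothetical and do not reflect the final design.
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union


@dataclass
class IdentifiedEstimand:
    """Placeholder result type, so the signature below can be annotated."""
    action_nodes: List[str]
    outcome_node: str
    backdoor_sets: List[List[str]] = field(default_factory=list)


def identify_effect(
    causal_graph: Any,  # e.g., a networkx.DiGraph
    action_node: Union[str, List[str]],
    outcome_node: str,
    observed_nodes: Optional[List[str]] = None,
) -> IdentifiedEstimand:
    """All inputs are explicit arguments; no state is kept in a model object."""
    actions = [action_node] if isinstance(action_node, str) else list(action_node)
    return IdentifiedEstimand(action_nodes=actions, outcome_node=outcome_node)
```

With annotations like these, mypy can statically check caller code such as `identify_effect(g, action_node="X", outcome_node="Y")`.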
# Ia. Create a networkx DiGraph as causal graph
# This can be later expanded to support ADMGs or other mixed graphs. To support this flexibility, we should use a Protocol class in Python. We can use the DirectedGraph class in https://github.com/py-why/dowhy/blob/main/dowhy/gcm/graph.py.
causal_graph = networkx.DiGraph(...)
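A minimal sketch of such a Protocol follows; the member names here are illustrative, while the real interface is the one in dowhy/gcm/graph.py.

```python
# Sketch: a structural Protocol that both networkx.DiGraph and future
# mixed-graph classes (e.g., ADMGs) could satisfy without inheritance.
# The method names below are illustrative only.
from typing import Any, Iterable, Protocol, runtime_checkable


@runtime_checkable
class DirectedGraph(Protocol):
    def predecessors(self, node: Any) -> Iterable: ...
    def successors(self, node: Any) -> Iterable: ...


class MinimalGraph:
    """Any class with the right methods satisfies the Protocol."""
    def __init__(self):
        self.edges = {"X": ["Y"]}

    def predecessors(self, node):
        return [p for p, children in self.edges.items() if node in children]

    def successors(self, node):
        return self.edges.get(node, [])
```

Because the Protocol is structural, `isinstance(MinimalGraph(), DirectedGraph)` is true even though `MinimalGraph` never subclasses anything.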
# Ia. [Alternative] Use causal discovery to construct the graph
# user can edit the graph after visualizing it, using standard networkx operations
...
# Ib. The user can validate whether the constraints from the graph are satisfied by data.
# This functionality is currently in https://github.com/py-why/dowhy/blob/main/dowhy/causal_refuters/graph_refuter.py. We need to write a refute_graph function that wraps around that class.
validate = dowhy.refute_graph(causal_graph, data)
# II. Identify the causal effect and return target estimands. This should be an easy
# refactor on top of the current CausalModel.identify_effect().
# Note that it can also accept an Iterable containing multiple action nodes.
estimand = dowhy.identify_effect(causal_graph,
                                 action_node="X",  # we use "action" because it is a more general term
                                 outcome_node="Y",
                                 observed_nodes=...)
# III. Fit and estimate the target estimand.
# method is an optional parameter here (default is "auto", meaning that we choose the best
# estimator based on the estimand). Internally, for a given estimation method EstMethod,
# this will call the class constructor, fit the model, and return the estimates:
# EstMethod(estimand).fit(data).estimate_effect(action_value=..., control_value=...)
estimate = dowhy.estimate_effect(data, estimand, action_value=..., control_value=...)
# If the user knows the method, they can provide it. Note that the method_params are now
# explicitly provided by the user in the LinearRegressionEstimator constructor.
estimate = dowhy.estimate_effect(data, estimand, method=dowhy.LinearRegressionEstimator(..),
                                 action_value=..., control_value=...)
# Or, directly calling the causal estimator class (for advanced users)
estimator = dowhy.LinearRegressionEstimator(estimand)
estimator.fit(data)
# IIIb. Estimate the target estimand using a statistical method.
estimate = estimator.estimate_effect(action_value=..., control_value=...)
# IIIb. GCM alternative. Note that we can also use a GCM to estimate the target estimand.
scm = dowhy.fit_scm(causal_graph, data, estimand=estimand)
# The estimand argument helps determine which nodes to fit; for example, there is no need
# to estimate nodes downstream of the outcome Y.
# do operation produces interventional samples
Y1 = dowhy.do(scm, estimand, input_values=[1])
Y0 = dowhy.do(scm, estimand, input_values=[0])
estimate = np.mean(Y1) - np.mean(Y0)
# IV. Refute the obtained estimate using multiple robustness checks. The call below
# automatically chooses the applicable refutation tests.
dowhy.refute_estimate(estimate, estimand, data)
# Or the user can specify the refutation test. Note that the method_params now go inside
# the constructor for AddUnobservedCommonCause.
# Also the method argument is a list because typically people want to run multiple refutation tests.
dowhy.refute_estimate(estimate, estimand, data, method=[dowhy.AddUnobservedCommonCause(..)])
# Alternatively, people can call the refuter directly.
dowhy.AddUnobservedCommonCause(estimate, estimand, data, ...).refute_estimate()
# I. Create a networkx DiGraph as causal graph
causal_graph = networkx.DiGraph(...)
# II-pre. (optional) Assign causal models to nodes
dowhy.set_causal_model(causal_graph, "X", MyConditionalStochasticModel())
# II. Auto-assign causal models to nodes and fit them to the data
scm = dowhy.fit_scm(causal_graph, data)
# III. Loop over different pairs of action_nodes and outcome_nodes to identify, estimate and refute causal effects
estimand = dowhy.identify_effect(causal_graph,
                                 action_node=["X1", "X2"],
                                 outcome_node="Y",
                                 observed_nodes=...)
Y1 = dowhy.do(scm, estimand, input_values=[1])
Y0 = dowhy.do(scm, estimand, input_values=[0])
estimate = np.mean(Y1) - np.mean(Y0)
# Refute estimate (same as before)
dowhy.refute_estimate(estimate, estimand, data,
                      method=[dowhy.AddRandomCommonCause(..)])
We add an extra argument to identify_effect for CATE estimation.
estimand = dowhy.identify_effect(
    causal_graph,
    action_node,
    outcome_node,
    conditioning_node=None)  # for heterogeneous treatment effect or interaction analysis;
                             # supports an array of nodes (i.e., pseudo-effect modifiers or effect modifiers)
- If conditioning_node is None, then there is no CATE estimation.
- If conditioning_node is non-null, then we are analyzing heterogeneous treatment effects conditioned on those nodes.
- If conditioning_node is "*", then the conditioning nodes will be determined by (a) the potential effect modifiers as determined by the causal graph; and (b) the effect inference algorithm.
Note:
- Identification happens as it does currently.
- Nodes in conditioning_node are tested for being ancestors of the outcome and action nodes:
  - Throw an error if they are descendants of the outcome.
  - Warn if they are ancestors of action nodes and d-separated from the outcome by the action nodes.
  - Warn if they are descendants of action nodes but not descendants of any other nodes.
  - While the warnings above flag valid questions to ask, unless the causal graph is missing an edge or node, there should be little heterogeneity in these cases (?)
- If we ever add the ability to assume that interactions are additive, then we should consider throwing a warning when running a CATE estimation.
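A hypothetical sketch of these checks, using networkx reachability on a DiGraph (the helper name and warning strings are illustrative, and the d-separation warning is omitted for brevity):

```python
# Sketch of the conditioning-node checks described above. The helper
# name and warning strings are hypothetical.
import networkx as nx


def validate_conditioning_nodes(graph, action_nodes, outcome_node, conditioning_nodes):
    warnings = []
    for node in conditioning_nodes:
        # Error: conditioning on a descendant of the outcome is invalid.
        if node in nx.descendants(graph, outcome_node):
            raise ValueError(f"{node} is a descendant of the outcome {outcome_node}")
        # Warning: descendants of action nodes rarely carry real heterogeneity.
        if any(node in nx.descendants(graph, a) for a in action_nodes):
            warnings.append(f"{node} is a descendant of an action node")
    return warnings
```

For example, in the graph X -> W, X -> Y, Z -> Y, conditioning on Z passes cleanly, conditioning on W produces a warning, and conditioning on a descendant of Y raises an error.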
Note: if the intention is to run a causal interaction analysis, where the goal is to analyze the joint effect of multiple actions, those nodes should be passed via the action_node parameter instead of the conditioning_node parameter.
estimator = dowhy.LinearRegressionEstimator(estimand)
estimator.fit(data)
estimate = estimator.estimate_effect(
action_value = ...,
control_value = ...,
conditioning_value = ...)
- If conditioning_value is None, then the estimand's conditioning_node must have been None as well.
- If conditioning_value is non-null, then the estimand's conditioning_node must have been non-null as well.
- If conditioning_value is "*", then the estimation algorithm will return heterogeneous treatment effects over algorithmically determined groups. Some estimation methods might not support this.
- If conditioning_value is a single set of values, then estimate_effect will return the treatment effect over the specified group. This set of values must have the dimensionality of the conditioning_node specified in the identify_effect function. If any value is given as "*", then the estimation algorithm will return heterogeneous treatment effects over algorithmically determined subgroups. If any value is given as "all", then the estimation algorithm will return a heterogeneous treatment effect averaged over all values of that feature. The returned estimate object includes a field identifying the subgroup. If any value is a "*", then the returned value will be an array of estimates (since the algorithm may split the given value into algorithmically determined subgroups).
- If conditioning_value is a collection of value-sets, then estimate_effect will return the treatment effect over each of the given groups, as described above. The returned estimate object is an array of estimates, each including a field identifying the subgroup. The returned array may not include elements for every requested group, and may include elements for groups that were not requested.
causal_graph = networkx.DiGraph(...)
When modeling panel data, our causal graph must include indexed nodes. Any node indexed in the causal graph must be indexed relative to a running variable. I.e., nodes may not have absolute indices or be indexed by other features in the graph. They may only be indexed by a running variable, plus or minus a constant. E.g., if $t$ is a running variable, then nodes such as $X_t$ and $X_{t-1}$ are allowed, but $X_5$ is not.
There may be multiple running variables, e.g., to accommodate spatial and temporal data. Note that graphs, of course, must still be DAGs, etc.
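As an illustrative sketch of the indexing rule, a validator could accept only running-variable-plus-constant indices; the "X[t-1]" string encoding of indexed node names and the helper are hypothetical.

```python
# Sketch: validate indexed node names of the form "X[t]", "X[t-1]", "X[t+2]".
# Only a declared running variable plus/minus a constant is allowed; absolute
# indices like "X[5]" are rejected. The "X[t-1]" naming format is hypothetical.
import re

_INDEXED = re.compile(r"^(?P<name>\w+)\[(?P<run>[a-zA-Z]\w*)(?P<offset>[+-]\d+)?\]$")


def parse_indexed_node(node, running_variables):
    match = _INDEXED.match(node)
    if match is None:
        if "[" in node:
            raise ValueError(f"{node}: indices must be a running variable +/- a constant")
        return None  # an unindexed (constant) node
    if match.group("run") not in running_variables:
        raise ValueError(f"{node}: {match.group('run')} is not a declared running variable")
    return match.group("name"), match.group("run"), int(match.group("offset") or 0)
```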
Note: Zeroth implementations, while we are working out the details of DAG representations, may simply be "dummy" graphs that are tagged to support specific algorithms. See discussion below on implementation notes.
Effect identification is not, in principle, changed. However, identification algorithms must be modified to handle indexed nodes. The simplest approach may be for dowhy to unroll the graph into a standard, unindexed DAG before running identification.
Conceptually, the data behind a causal graph includes both unindexed data (i.e., constant data that is not indexed by any running variable); data indexed by some but not all running variables; and data that is indexed by all running variables.
For example, consider a causal graph governed by 3 running variables: a given node may be unindexed, indexed by any one or two of them, or indexed by all three.
Estimate effect takes a dataframe that is either jagged, long-form, or sparse.
- Long-form dataframe - a long-form dataframe creates a representation of multi-dimensional data that duplicates non-indexed or partially-indexed variables. E.g., given a constant feature $A$ and a time index $t$, then $A_t$ will be set to the constant $A$ for all values of the index $t$.
- Jagged dataframe
- Sparse dataframe - like a long-form dataframe, but all index values must be columns in the dataframe as well.
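For instance, the long-form convention can be produced with a pandas merge; the column names below are hypothetical.

```python
# Sketch: the long-form convention duplicates unindexed data across the index.
# Here the constant feature A is repeated for every value of the time index t.
import pandas as pd

units = pd.DataFrame({"unit": ["u1", "u2"], "A": [10, 20]})   # unindexed data
panel = pd.DataFrame({"unit": ["u1", "u1", "u2", "u2"],
                      "t": [0, 1, 0, 1],
                      "Y": [1.0, 1.5, 2.0, 2.5]})             # fully indexed data
long_form = panel.merge(units, on="unit")  # A_t equals the constant A for all t
```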
Q: should be more explicit about the data type: MultiIndex?
Q: is a unit index just another running variable? maybe there's a reserved index i for the unit index. For backwards compatibility, we can assume that a causal graph that has no explicit running indices has a unit index of i under each causal graph. that should give us consistency of representations.
Panel data estimation methods that return only LATE estimates (e.g., regression discontinuities) will return results that are marked with the subgroup for which the estimate is valid, like the CATE estimation proposed above.
For the experimental or 0th implementation, I expect that we will not have support for analyzing arbitrary indexed causal graphs. Rather, we will provide helper functions to create specific causal graphs for 2-3 key panel data analysis methods. These helper functions will create easily recognizable indexed causal graphs, and all other indexed causal graphs will throw a not-yet-implemented error.
The purpose of this is to keep the API consistent across effect estimation methods, and then to get the basic algorithms working. Once we have the basic algorithms working, we can then iteratively improve the graphical and non-graphical representations of necessary assumptions in the causal model, and automation or generalization of identification strategies.
I expect long-form dataframes to be the initial implementation of data representation for most algorithms. Jagged and/or sparse data frames will not be supported or will be converted to long-form dataframes initially.
I expect the following panel data algorithms will be supported initially:
- EconML time series analyses (e.g., dynamic treatment analysis)
- difference-in-difference
- synthetic controls
- synthetic difference-in-difference
- sparse synthetic controls (SparseSC)
We may add support for unrolling certain time series data for analysis via non-panel analysis methods.
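A minimal sketch of what unrolling could look like (column names hypothetical): a lagged copy of an indexed variable becomes an ordinary column, so a non-panel estimator can consume it.

```python
# Sketch: unroll X[t-1] into a plain column via a per-unit shift.
import pandas as pd

df = pd.DataFrame({"unit": ["u1"] * 4, "t": [0, 1, 2, 3], "X": [1, 2, 3, 4]})
df["X_lag1"] = df.groupby("unit")["X"].shift(1)  # X[t-1] as an ordinary column
unrolled = df.dropna()  # drop rows lacking a full lag history
```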
We've also received requests for survival analysis methods over panel data.
In the below examples, I assume the addition of a helper function to create the initial causal graph consistent with the selected estimator.
# The SyntheticControlsGraph helper function creates a causal graph representation that
# captures the assumptions necessary for running a synthetic control analysis.
# (Note: non-graphical assumptions should also be included explicitly. A first implementation
# might capture them with a simple tag on the causal graph while we develop a more
# sophisticated approach to capturing non-graphical assumptions.)
causalgraph = dowhy.panelanalysis.SyntheticControlsGraph(features=['A', 'B', 'C'],
                                                         outcomes=['Y'],
                                                         treatment_status='T',
                                                         running_index='t')
estimand = dowhy.identify_effect(causalgraph,
                                 action_node=['T'],
                                 outcome_node=['Y'])
# load panel data (long-form dataframe)
panel_data = pd.read_csv('panel_data.csv')
estimator = dowhy.SyntheticControlsEstimator(estimand)
# The SyntheticControlsEstimator's fit method should validate non-graphical assumptions
# on the data, e.g., that once units are treated, they remain treated.
estimator.fit(panel_data)
estimate = estimator.estimate_effect(action_value=..., control_value=...)
# Refute the obtained estimate using multiple robustness checks.
dowhy.refute_estimate(estimate, estimand, panel_data, causalgraph,
                      method_name="random_common_cause")
# The DiffInDiffGraph helper function creates a causal graph representation that captures
# the assumptions necessary for running a diff-in-diff analysis.
causalgraph = dowhy.panelanalysis.DiffInDiffGraph(outcomes=['Y'],
                                                  treatment_status='T',
                                                  running_index='t')
estimand = dowhy.identify_effect(causalgraph,
                                 action_node=['T'],
                                 outcome_node=['Y'])
# load panel data (long-form dataframe)
panel_data = pd.read_csv('panel_data.csv')
estimator = dowhy.DiffInDiffEstimator(estimand)
# The estimator's fit method should validate non-graphical assumptions on the data,
# e.g., that once units are treated, they remain treated.
estimator.fit(panel_data)
estimate = estimator.estimate_effect(action_value=..., control_value=...)
# Refute the obtained estimate using multiple robustness checks.
dowhy.refute_estimate(estimate, estimand, panel_data, causalgraph,
                      method_name="random_common_cause")
See API Proposal for Causal Prediction
Our current plans include GCM-based attribution. We are looking for suggestions on non-GCM-based methods for causal attribution.
# Fit an SCM as before
# Analyze root causes from anomalous data.
anomaly_scores = dowhy.anomaly_scores(scm, potentially_anomalous_data)
# Perform an interventional analysis on a particular node. Both soft and hard (do) interventions are supported.
intervention_samples = dowhy.intervene(scm, target_node="X", intervention_func=...)
# Attribute distribution changes
change_attributions = dowhy.distribution_change(scm, data, new_data)
Counterfactuals do require access to a fitted GCM.
# assume causal graph is provided
scm = dowhy.fit_scm(causal_graph, data)
identified_cf = dowhy.identify_counterfactual(scm, evidence, action, outcome)
cf = dowhy.estimate_counterfactual(identified_cf, data)
dowhy.refute_counterfactual(cf, model=scm)