
API proposal for v1

Amit Sharma edited this page Sep 8, 2022 · 22 revisions

DoWhy's current API was designed for the effect estimation task. As the library expands to more tasks, we revisit some of the design decisions and propose an updated API. In particular, the most significant change is moving from an object-oriented API to a functional API. Beyond this change, our goal is to retain the current effect estimation API with exactly the same input-output signature while adding new API functions to include other tasks.

As you can see below, we envision two ways of achieving each task: 1) the task-specific API, and 2) using a graphical causal model (GCM). Having access to a fitted GCM simplifies the computation of almost every task, but it requires knowledge of the full causal graph. Hence, if the full graph is known, we suggest using the GCM API; otherwise, the task-specific API can be used.

We welcome your contributions and feedback; you can join in on the project's discussion page.

Current API for effect estimation

DoWhy started out by using a CausalModel class for the effect estimation task, and the current API provides an easy four-step process.

m = CausalModel(data, graph, ..)
identified_estimand = m.identify_effect()
estimate = m.estimate_effect(identified_estimand,                                                                 
                             method="dowhy.propensity_score_matching")
refute = m.refute_estimate(identified_estimand, estimate,
                           method="placebo_treatment_refuter")

However, different tasks may need to use the CausalModel differently. Further, object-oriented methods do not explicitly mention all the inputs required (e.g., the identification method does not require access to data).

To make the function arguments explicit and avoid book-keeping inside the CausalModel class for different tasks, we propose switching to a functional approach. The idea is that the same end-user code will work in the new API, often just by replacing CausalModel.method with dowhy.method([data,graph]) where one of the two parameters may be optional. Note that individual identifiers, estimators, and refuters still remain as classes. It is only the user-facing functions that are being changed to a functional API.

In addition, we would like to add mypy type annotations to the DoWhy API.

So there are two goals: 1) move to a functional API; 2) add mypy types in the API.
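As a hypothetical sketch of what these two goals look like together, the snippet below types a small functional-style helper against a Protocol, in the spirit of the DirectedGraph class mentioned later; the names and the helper are illustrative, not the proposed API.

```python
from typing import Any, Iterable, Protocol

import networkx as nx

class DirectedGraph(Protocol):
    """Structural type: any graph exposing predecessors() qualifies,
    including networkx.DiGraph."""
    def predecessors(self, node: Any) -> Iterable[Any]: ...

def parents(graph: DirectedGraph, node: Any) -> list:
    """A functional-style helper: every input is an explicit argument,
    so mypy can check call sites against the Protocol."""
    return list(graph.predecessors(node))

g = nx.DiGraph([("Z", "X"), ("Z", "Y"), ("X", "Y")])
print(parents(g, "Y"))  # the direct causes of Y in the graph
```

Because the Protocol is structural, networkx graphs satisfy it without inheriting from any DoWhy class.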

New API proposal for effect estimation

Effect inference API

# Ia. Create a networkx DiGraph as the causal graph.
# This can later be expanded to support ADMGs or other mixed graphs. To support
# this flexibility, we should use a Protocol class in Python, e.g., the DirectedGraph
# class in https://github.com/py-why/dowhy/blob/main/dowhy/gcm/graph.py.

causal_graph = networkx.DiGraph(...)
# Ia. [Alternative] Use causal discovery to obtain the graph.
# The user can edit the graph after visualizing it, using standard networkx operations.
...
# Ib. The user can validate whether the constraints implied by the graph are
# satisfied by the data. This functionality currently lives in
# https://github.com/py-why/dowhy/blob/main/dowhy/causal_refuters/graph_refuter.py;
# we need to write a refute_graph function that wraps around that class.
validate = dowhy.refute_graph(causal_graph, data)

# II. Identify the causal effect and return target estimands. This should be an
# easy refactor on top of the current CausalModel.identify_effect().
estimand = dowhy.identify_effect(causal_graph,
                                 action_node="X",  # "action" is a more general term
                                 outcome_node="Y",
                                 observed_nodes=...)
# III. Fit the model and estimate the target estimand. The method argument is
# optional here (default is "auto", meaning that we choose the best estimator
# based on the estimand). Internally, for a given estimation method EstMethod,
# this will initialize the class constructor, fit the model, and return the
# estimates: EstMethod(estimand).fit(data).estimate_effect(action_value=..., control_value=...).
estimate = dowhy.estimate_effect(data, estimand, action_value=..., control_value=...)

# If the user knows the method, they can provide it. Note that the method_params
# are now explicitly provided by the user in the LinearRegressionEstimator constructor.
estimate = dowhy.estimate_effect(data, estimand, method=dowhy.LinearRegressionEstimator(...),
                                 action_value=..., control_value=...)

# Or, directly calling the causal estimator class (for advanced users)
estimator = dowhy.LinearRegressionEstimator(estimand)
estimator.fit(data)

# IIIb. Estimate the target estimand using a statistical method.
estimate = estimator.estimate_effect(action_value=..., control_value=...)
# IIIb. Alternative. 
estimate = estimator.do(value=...) - estimator.do(value=...)
                                     
# IV. Refute the obtained estimate using multiple robustness checks. The call below
# automatically chooses the applicable refutation tests.
dowhy.refute_estimate(estimate, estimand, data)

# Or the user can specify the refutation tests. Note that the method_params now go
# inside the refuter constructor (e.g., AddUnobservedCommonCause). The method
# argument is a list because people typically want to run multiple refutation tests.
dowhy.refute_estimate(estimate, estimand, data, method=[dowhy.AddUnobservedCommonCause(..)])

# Alternatively, people can call the refuter directly. 
dowhy.AddUnobservedCommonCause(estimate, estimand, data, ...).refute_estimate()
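For intuition about what a placebo-treatment refutation checks (independent of DoWhy's refuter classes), here is a self-contained simulation: re-estimating with a randomly permuted treatment should make the "effect" collapse toward zero. The data and the difference-in-means estimator are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
t = rng.integers(0, 2, n)             # binary treatment
y = 2.0 * t + rng.normal(size=n)      # outcome with a true effect of 2

def diff_in_means(treat, outcome):
    return outcome[treat == 1].mean() - outcome[treat == 0].mean()

estimate = diff_in_means(t, y)                   # recovers roughly 2
placebo = diff_in_means(rng.permutation(t), y)   # placebo treatment: roughly 0
print(estimate, placebo)
```

If the placebo "effect" does not collapse, the original estimate is likely driven by something other than the treatment.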

GCM-based API

# I. Create a networkx DiGraph as causal graph
causal_graph = networkx.DiGraph(...)


# II. Identify causal effect and return target estimands
estimand = dowhy.identify_effect(causal_graph,
                                 action_node=["X1", "X2"],
                                 outcome_node="Y",
                                 observed_nodes=...)

# III-pre. (optional) Assign causal models to nodes
dowhy.set_causal_model(causal_graph, "X", MyConditionalStochasticModel())

# IIIa. Auto-assign causal models to nodes and fit them to the data
scm = dowhy.fit_scm(causal_graph, data)

# IIIb. Estimate effect 
# do operation produces interventional samples
Y1 = dowhy.do(scm, estimand, input_values=[1])
Y0 = dowhy.do(scm, estimand, input_values=[0])
estimate = np.mean(Y1) - np.mean(Y0)

# IV. Refute estimate (same as before)
dowhy.refute_estimate(estimate, estimand, data, causal_graph,
                      method_name="random_common_cause")
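For intuition about what the do operation computes, here is a self-contained numpy sketch with a hand-coded SCM standing in for a fitted GCM; the structural equations and effect size are made up.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 20000

def sample_interventional(x_value):
    """Interventional sampling: do(X=x) cuts the Z -> X edge, so X is
    set directly while Z and the noise are drawn as usual."""
    z = rng.normal(size=n)                  # exogenous confounder Z
    x = np.full(n, float(x_value))          # do(X = x_value)
    y = 3.0 * x + z + rng.normal(size=n)    # structural equation for Y
    return y

y1 = sample_interventional(1)
y0 = sample_interventional(0)
ate = y1.mean() - y0.mean()
print(ate)  # close to the true effect of 3
```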

Extension for conditional average treatment effect estimation

Identify Effect

We add two extra arguments to identify_effect for CATE estimation.

estimand = dowhy.identify_effect(
    causal_graph,
    action_node,
    outcome_node,
    grouping_nodes=null,   # for heterogeneous treatment effect or interaction analysis
                           # (i.e., pseudo-effect modifier or effect modifier)
    grouping_analysis=[none, fixed, causal])  # default is none --- EMK: changed
                           # "correlational" to "fixed" following PSantana's feedback

How does identify_effect change to support CATE?

If grouping_nodes is null, then grouping_analysis must be null as well.

If grouping_nodes is non-null, then grouping_analysis must not be null or none.

If grouping_nodes is "*" then:

  1. grouping_nodes will be determined by (a) the potential effect modifiers as determined by the causal graph; and (b) the effect inference algorithm.

grouping_analysis is intended to capture whether or not we are doing a heterogeneous effect analysis and, if so, whether we are doing a correlational analysis (i.e., a heterogeneous effect analysis where effect modifiers are treated as correlational definitions of subgroups) or a causal analysis of interactions between the main treatment and the effect modifiers.

If grouping_analysis is null or none, then identify_effect works as it does today, and grouping_nodes must be null as well.

If grouping_analysis is 'fixed', then we are analyzing $P(Y|do(action_{node}),grouping_{nodes})$:

  1. Identification happens as it does today.
  2. grouping_nodes are tested for being ancestors of the outcome and action nodes:
     • throw an error if they are descendants of the outcome
     • warn if they are ancestors of action nodes and d-separated from the outcome by the action nodes
     • warn if they are descendants of action nodes but not descendants of any other nodes
     • while the warnings above flag valid questions to ask, unless the causal graph is missing an edge or node, there should be little heterogeneity in this case (?)
  3. Question: should a warning be thrown if the system is assumed to be additive?

If grouping_analysis is causal, then we are analyzing $P(Y|do(action_{node}),do(grouping_{nodes}))$:

  1. This is more complicated: we have to figure out appropriate control groups, etc. At a minimum, there are many checks to do on the grouping_nodes, similar to the above.
  2. The initial "implementation" will be to throw an error.
  3. Q: an alternative to this API would be to add multiple action nodes for causal interaction analysis, and leave grouping only for correlational analyses.

Estimate Effect

estimator = dowhy.LinearRegressionEstimator(estimand)
estimator.fit(data)

estimate = estimator.estimate_effect( 
                action_value = ..., 
                control_value = ...,
                grouping_values = ...)
How does estimate_effect change to support CATE?

If grouping_values is null, then the estimand's grouping_analysis must have been none.

If grouping_values is non-null, then the estimand's grouping_analysis must have been set to fixed or causal.

If grouping_values is "*", then the estimation algorithm will return heterogeneous treatment effects over algorithmically determined groups.

If grouping_values is a single set of values, then estimate_effect will return the treatment effect over the specified group. This set of values must have the dimensionality of grouping_nodes as specified in the identify_effect function. If any value is given as "*", then the estimation algorithm will return heterogeneous treatment effects over algorithmically determined subgroups, and the returned value will be an array of estimates (since the algorithm may split the given value into algorithmically determined subgroups). If any value is given as all, then the estimation algorithm will return a heterogeneous treatment effect averaged over all values of that feature. The returned estimate object includes a field identifying the subgroup.

If grouping_values is a collection of value-sets, then estimate_effect will return treatment effect over each of the given groups, as described above. The returned estimate object is an array of estimates, each including a field identifying the subgroup. The returned array may not include elements for every requested group, and may include elements for groups that were not requested.
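The grouping semantics above can be illustrated with a self-contained simulation in which the true effect differs by a grouping node and per-group estimation recovers it; the data and the difference-in-means estimator are illustrative, not DoWhy's.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 20000
g = rng.integers(0, 2, n)   # grouping node (effect modifier)
t = rng.integers(0, 2, n)   # randomized treatment
y = (1.0 + 2.0 * g) * t + rng.normal(size=n)  # effect is 1 if g=0, 3 if g=1

def cate(group_value):
    """Difference in means within the requested subgroup."""
    mask = g == group_value
    return y[mask & (t == 1)].mean() - y[mask & (t == 0)].mean()

print(cate(0), cate(1))  # roughly 1 and 3
```

A grouping_values argument of, say, {g: 0} would correspond to calling cate(0) here, while "*" would let the algorithm choose the subgroups.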

Extensions to support effect estimation over panel data and time series.

Model graph

causal_graph = networkx.DiGraph(...)

When modeling panel data, our causal graph must include indexed nodes. Any node indexed in the causal graph must be indexed relative to a running variable. I.e., nodes may not have absolute indices or be indexed by other features in the graph. They may only be indexed by a running variable, plus or minus a constant. E.g., if $t$ is a running variable, then $n_{t-k}$ may be a node in the graph, where $k$ is a constant.

There may be multiple running variables, e.g., to accommodate spatial and temporal data. Note that graphs, of course, must still be DAGs, etc.

Note: Zeroth implementations, while we are working out the details of DAG representations, may simply be "dummy" graphs that are tagged to support specific algorithms. See discussion below on implementation notes.

Identify effect

Effect identification is not, in principle, changed. However, identification algorithms must be modified to handle indexed nodes. The simplest approach may be for dowhy to unroll the graph $N$ times, and then add appropriate unobserved confounding at the last layer for correctness.
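The unrolling idea can be sketched with networkx; the node names and the helper function are illustrative, not a proposed DoWhy function.

```python
import networkx as nx

def unroll(n_steps):
    """Expand the indexed edges X_{t-1} -> X_t and X_t -> Y_t into an
    explicit DAG over n_steps time steps."""
    g = nx.DiGraph()
    g.add_edge("X_0", "Y_0")
    for t in range(1, n_steps):
        g.add_edge(f"X_{t-1}", f"X_{t}")   # lag edge
        g.add_edge(f"X_{t}", f"Y_{t}")     # contemporaneous edge
    return g

unrolled = unroll(4)
print(sorted(unrolled.nodes))
print(nx.is_directed_acyclic_graph(unrolled))  # True: unrolling preserves acyclicity
```

Standard identification algorithms can then run unchanged on the unrolled graph.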

Estimate effect

Conceptually, the data behind a causal graph includes unindexed data (i.e., constant data that is not indexed by any running variable); data indexed by some but not all running variables; and data indexed by all running variables.

For example, consider a causal graph governed by 3 running variables $i,j,k$. A constant feature $A$ will have $A = A_{i,j,k}$ for all values of $i,j,k$. A feature indexed by a subset of the running variables will have $A_{i} = A_{i,j,k}$ for all values of $j,k$.

Estimate effect takes a dataframe that is either jagged, long-form, or sparse.

  • Long-form dataframe - a long-form dataframe creates a representation of multi-dimensional data that duplicates non-indexed or partially-indexed variables. E.g., given a constant feature $A$ and a time index $t$, then $A_t$ will be set to the constant $A$ for all values of the index $t$.

  • Jagged dataframe - each unit may have a different number of rows (e.g., units observed over different spans of the running variable).

  • Sparse dataframe - Like a long-form dataframe, but all index values must be columns in the dataframe as well.
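The long-form representation above can be sketched with pandas; the column names and values are illustrative.

```python
import pandas as pd

units, periods = ["u1", "u2"], [0, 1, 2]
# One row per (unit, t) pair.
long_form = pd.DataFrame(
    [(u, t) for u in units for t in periods], columns=["unit", "t"]
)
# A is constant per unit; merging duplicates it across every t, i.e. A_t == A.
constants = pd.DataFrame({"unit": units, "A": [10.0, 20.0]})
long_form = long_form.merge(constants, on="unit")
long_form["Y"] = range(len(long_form))  # a fully indexed outcome
print(long_form)
```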

Q: should we be more explicit about the data type (e.g., a pandas MultiIndex)?

Q: is a unit index just another running variable? Maybe there is a reserved index i for the unit index. For backwards compatibility, we can assume that a causal graph with no explicit running indices has an implicit unit index i; that should give us consistency of representations.

Panel data estimation methods that return only LATE estimates (e.g., regression discontinuities) will return results that are marked with the subgroup for which the estimate is valid, like the CATE estimation proposed above.

First implementation note

For the experimental or 0th implementation, I expect that we will not have support for analyzing arbitrary indexed causal graphs. Rather, we will provide helper functions to create specific causal graphs for 2-3 key panel data analysis methods. These helper functions will create easily recognizable indexed causal graphs, and all other indexed causal graphs will throw a not-yet-implemented error.

The purpose of this is to keep the API consistent across effect estimation methods, and then to get the basic algorithms working. Once we have the basic algorithms working, we can then iteratively improve the graphical and non-graphical representations of necessary assumptions in the causal model, and automation or generalization of identification strategies.

I expect long-form dataframes to be the initial implementation of data representation for most algorithms. Jagged and/or sparse data frames will not be supported or will be converted to long-form dataframes initially.

I expect the following panel data algorithms will be supported initially:

  • EconML time series analyses (e.g., dynamic treatment analysis)
  • difference-in-difference
  • synthetic controls
  • synthetic difference-in-difference
  • sparse synthetic controls (SparseSC)

We may add support for unrolling certain time series data for analysis via non-panel analysis methods.

We've also received requests for survival analysis methods over panel data.

Panel data Examples

In the examples below, I assume the addition of a helper function that creates an initial causal graph consistent with the selected estimator.

Synthetic Controls

# The SyntheticControlsGraph helper function creates a causal graph representation
# that captures the assumptions necessary for running a synthetic control analysis.
# (Note: non-graphical assumptions should also be included explicitly. A first
# implementation might capture them with a simple tag on the causal graph while we
# develop a more sophisticated approach to capturing non-graphical assumptions.)
causalgraph = dowhy.panelanalysis.SyntheticControlsGraph(features = ['A','B','C'],
                                            outcomes = ['Y'],
                                            treatment_status = 'T',
                                            running_index = 't')

estimand = dowhy.identify_effect(causalgraph,
                                 action_node = ['T'],
                                 outcome_node = ['Y'])

# load panel data (long-form dataframe)
panel_data = pd.read_csv('panel_data.csv')

estimator = dowhy.SyntheticControlsEstimator(estimand)

# The SyntheticControlsEstimator's fit method should validate non-graphical
# assumptions on the data, e.g., that once units are treated, they remain treated.
estimator.fit(panel_data)

estimate = estimator.estimate_effect(action_value=..., control_value=...)
                                     
# Refute the obtained estimate using multiple robustness checks.
dowhy.refute_estimate(estimate, estimand, panel_data, causalgraph,
                      method_name="random_common_cause")
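For intuition, here is a rough numpy sketch of what a synthetic-control fit does: learn donor weights that reproduce the treated unit's pre-treatment path, then compare post-treatment outcomes against the weighted donors. The classic method constrains the weights to a simplex; unconstrained least squares is used here as a simplification, and all data is simulated.

```python
import numpy as np

rng = np.random.default_rng(3)
n_donors, t_pre, t_post = 20, 30, 10
donors = rng.normal(size=(t_pre + t_post, n_donors))   # donor unit outcomes
true_w = rng.dirichlet(np.ones(n_donors))              # hidden combination weights
treated = donors @ true_w                              # untreated path of the treated unit
treated[t_pre:] += 5.0                                 # treatment effect after t_pre

# Fit weights on the pre-treatment period only.
w, *_ = np.linalg.lstsq(donors[:t_pre], treated[:t_pre], rcond=None)
synthetic = donors[t_pre:] @ w                         # synthetic control path
effect = (treated[t_pre:] - synthetic).mean()
print(effect)  # recovers the true effect of 5
```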

Difference-in-Difference

# The DiffInDiffGraph helper function creates a causal graph representation that captures the assumptions necessary for running a difference-in-differences analysis.
causalgraph = dowhy.panelanalysis.DiffInDiffGraph(outcomes = ['Y'],
                                            treatment_status = 'T',
                                            running_index = 't')

estimand = dowhy.identify_effect(causalgraph,
                                 action_node = ['T'],
                                 outcome_node = ['Y'])

# load panel data (long-form dataframe)
panel_data = pd.read_csv('panel_data.csv')

estimator = dowhy.DiffInDiffEstimator(estimand)

# The DiffInDiffEstimator's fit method should validate non-graphical assumptions
# on the data, e.g., that once units are treated, they remain treated.
estimator.fit(panel_data)

estimate = estimator.estimate_effect(action_value=..., control_value=...)
                                     
# Refute the obtained estimate using multiple robustness checks.
dowhy.refute_estimate(estimate, estimand, panel_data, causalgraph,
                      method_name="random_common_cause")
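The underlying 2x2 difference-in-differences computation is simple enough to state directly; the numbers are invented for illustration.

```python
# Group/period means for a 2x2 design.
means = {
    ("treated", "pre"): 10.0,
    ("treated", "post"): 15.0,  # common trend + treatment effect
    ("control", "pre"): 8.0,
    ("control", "post"): 11.0,  # common trend only (parallel-trends assumption)
}
# DiD = (treated post - treated pre) - (control post - control pre)
did = (means[("treated", "post")] - means[("treated", "pre")]) - (
    means[("control", "post")] - means[("control", "pre")]
)
print(did)  # 2.0
```

Subtracting the control group's change nets out the common trend, which is exactly what the parallel-trends assumption licenses.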


API for new tasks: Causal Prediction


See API Proposal for Causal Prediction


API for new tasks: Attribution


Our current plans include GCM-based attribution. We are looking for suggestions on non-GCM-based methods for causal attribution.

GCM-based API

# Fit an SCM as before
# Analyze root causes from anomalous data.
anomaly_scores = dowhy.anomaly_scores(scm, potentially_anomalous_data)
# Perform an interventional analysis on a particular node. Both soft and hard (do) interventions are supported.
intervention_samples = dowhy.intervene(scm, target_node="X", intervention_func=...)
# Attribute distribution changes
change_attributions = dowhy.distribution_change(scm, data, new_data)

API for new tasks: Counterfactual Estimation


Counterfactuals do require access to a fitted GCM.

GCM-based API

# assume causal graph is provided
scm = dowhy.fit_scm(causal_graph, data)
identified_cf = dowhy.identify_counterfactual(scm, evidence, action, outcome)
cf = dowhy.estimate_counterfactual(identified_cf, data)
dowhy.refute_counterfactual(cf, model=scm)
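The abduction-action-prediction steps that counterfactual estimation performs on a fitted SCM can be sketched with a hand-coded linear SCM; the structural equation Y = 2X + U is illustrative, standing in for a fitted GCM.

```python
def abduct(x_obs, y_obs):
    """Abduction: recover the exogenous noise U from the evidence,
    given the structural equation Y = 2*X + U."""
    return y_obs - 2.0 * x_obs

def counterfactual_y(x_obs, y_obs, x_new):
    u = abduct(x_obs, y_obs)   # abduction: infer U
    return 2.0 * x_new + u     # action do(X=x_new), then prediction

# Evidence X=1, Y=3 implies U=1; under do(X=0) the counterfactual Y is 1.
print(counterfactual_y(1.0, 3.0, 0.0))  # 1.0
```

Keeping the inferred noise fixed while changing the action is what distinguishes a counterfactual from a plain interventional query, and is why a fitted GCM is required.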