API proposal for v1
DoWhy's current API was designed for the effect estimation task. As the library expands to more tasks, we revisit some of the design decisions and propose an updated API. In particular, the most significant change is moving from an object-oriented API to a functional API. Beyond this change, our goal is to retain the current effect estimation API with exactly the same input-output signature while adding new API functions to include other tasks.
As you can see below, we envision two ways of achieving each task: 1) the task-specific API, and 2) using a Graphical Causal Model (GCM). Having access to a fitted GCM simplifies the computation of almost every task, but it requires knowledge of the full causal graph. Hence if the full graph is known, we suggest using the GCM API. For most other tasks, the common API can be used.
We welcome your contributions and feedback. You can join the conversation on the discussion page.
DoWhy started out by using a CausalModel class for the effect estimation task, and the current API provides an easy four-step process.
from dowhy import CausalModel

m = CausalModel(data, graph, ...)
identified_estimand = m.identify_effect()
estimate = m.estimate_effect(identified_estimand,
                             method_name="backdoor.propensity_score_matching")
refute = m.refute_estimate(identified_estimand, estimate,
                           method_name="placebo_treatment_refuter")
However, different tasks may need to use the CausalModel differently. Further, object-oriented methods do not explicitly mention all the inputs required (e.g., the identification method does not require access to data).
To make the function arguments explicit and avoid bookkeeping inside the CausalModel class for different tasks, we propose switching to a functional approach. The idea is that the same end-user code will work in the new API, often just by replacing CausalModel.method with dowhy.method([data,graph]), where one of the two parameters may be optional.
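To illustrate why explicit arguments help, here is a minimal sketch in plain Python. The function names (`parents_of`, `identify_adjustment_set`) and the edge-list graph are hypothetical and are not part of DoWhy. The point is only that an identification step can be written as a function of the graph alone, with no data argument. The sketch assumes a DAG in which every parent of the action node is observed, so those parents form a valid back-door adjustment set.

```python
# Hypothetical sketch: these names are illustrative, not DoWhy functions.

def parents_of(edges, node):
    """Return the sorted parents of `node` in a list of (source, target) edges."""
    return sorted(source for source, target in edges if target == node)


def identify_adjustment_set(edges, action_node, outcome_node):
    """Pick the parents of the action node as a back-door adjustment set.

    Needs only the graph. Assumes a DAG in which every parent of the
    action node is observed.
    """
    parents = parents_of(edges, action_node)
    if outcome_node in parents:
        raise ValueError("outcome must not be a parent of the action node")
    return parents


# Confounded graph: W -> X, W -> Y, X -> Y
edges = [("W", "X"), ("W", "Y"), ("X", "Y")]
adjustment_set = identify_adjustment_set(edges, "X", "Y")
print(adjustment_set)  # ['W']
```

Because the signature lists only the graph and the two nodes, a reader can see at a glance that no dataset is consulted at this stage.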
# Ia. Create a networkx DiGraph as causal graph
# This can be later expanded to support ADMGs or other mixed graphs
# But the assumption is that the object will be a networkx Graph object or at least follows the basic graph interface that we define below.
causal_graph = networkx.DiGraph(...)
# Ia. [Alternative] uses causal discovery
# user can edit the graph after visualizing it, using standard networkx operations
...
# Ib. The user can validate whether the constraints from the graph are satisfied by data
validate = dowhy.refute_graph(causal_graph, data)
# II. Identify causal effect and return target estimands
estimand = dowhy.identify_effect(causal_graph,
                                 action_node="X",
                                 outcome_node="Y",
                                 observed_nodes=...)
# IIIa. Fit estimand
# Directly calling the causal estimator class
estimator = dowhy.LinearRegressionEstimator(estimand)
estimator.fit(data)
# IIIb. Estimate the target estimand using a statistical method.
estimate = estimator.estimate_effect(action_value=..., control_value=...)
# IIIb. Alternative.
estimate = estimator.do(value=...) - estimator.do(value=...)
# IV. Refute the obtained estimate using multiple robustness checks.
dowhy.refute_estimate(estimate, estimand, data, causal_graph,
                      method_name="random_common_cause")
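Steps IIIa and IIIb can be pictured with a small self-contained simulation. The sketch below is not DoWhy's LinearRegressionEstimator. It is a plain-Python stand-in that assumes a linear model with a single observed confounder W, and the helper names (`fit_line`, `residualize`, `do`) are hypothetical. It adjusts for W by partialling it out of both X and Y, then expresses the effect as the difference between two `do` values, mirroring the alternative form of step IIIb.

```python
import random

# Hypothetical stand-in for a regression-based estimator; not DoWhy code.

def fit_line(x, y):
    """Ordinary least squares for y = intercept + slope * x."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def residualize(control, values):
    """Remove the part of `values` that is linearly explained by `control`."""
    slope, intercept = fit_line(control, values)
    return [v - (intercept + slope * c) for c, v in zip(control, values)]


rng = random.Random(0)
n = 5000
w = [rng.gauss(0, 1) for _ in range(n)]  # confounder
x = [wi + rng.gauss(0, 1) for wi in w]   # action
# Simulated outcome with a true effect of 2 for X
y = [2.0 * xi + 3.0 * wi + rng.gauss(0, 1) for xi, wi in zip(x, w)]

naive_slope, _ = fit_line(x, y)  # biased upward by the confounder
adjusted_slope, _ = fit_line(residualize(w, x), residualize(w, y))


def do(value):
    """Mean outcome if X were set to `value` for every unit (linear model)."""
    return sum(yi + adjusted_slope * (value - xi) for xi, yi in zip(x, y)) / n


estimate = do(1.0) - do(0.0)
print(round(naive_slope, 2), round(estimate, 2))
```

In this simulation the naive slope overstates the effect because W drives both X and Y, while the adjusted difference of `do` values lands close to the simulated effect of 2.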
# I. Create a networkx DiGraph as causal graph
causal_graph = networkx.DiGraph(...)
# II. Identify causal effect and return target estimands
estimand = dowhy.identify_effect(causal_graph,
                                 action_node=["X1", "X2"],
                                 outcome_node="Y",
                                 observed_nodes=...)
# III-pre. (optional) Assign causal models to nodes
dowhy.set_causal_model(causal_graph, "X", MyConditionalStochasticModel())
# IIIa. Auto-assign causal models to nodes and fit them to the data
scm = dowhy.fit_scm(causal_graph, data)
# IIIb. Estimate effect
# do operation produces interventional samples
Y1 = dowhy.do(scm, estimand, input_values=[1])
Y0 = dowhy.do(scm, estimand, input_values=[0])
estimate = np.mean(Y1) - np.mean(Y0)
# IV. Refute estimate (same as before)
dowhy.refute_estimate(estimate, estimand, data, causal_graph,
                      method_name="random_common_cause")
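As a rough picture of what "fit an SCM, then draw interventional samples" can mean, the following toy example fits a nonparametric model on simulated data with a binary confounder W and a binary action X. The names `fit_toy_scm` and `do_x` are hypothetical and do not correspond to `dowhy.fit_scm` or `dowhy.do`. The sketch also assumes additive noise on Y that is independent of its parents.

```python
import random
from collections import defaultdict

# Hypothetical toy SCM for W -> X, W -> Y, X -> Y; not DoWhy code.
rng = random.Random(1)


def simulate(n):
    """Draw (w, x, y) rows; the simulated effect of X on Y is 2."""
    rows = []
    for _ in range(n):
        w = int(rng.random() < 0.5)
        x = int(rng.random() < 0.2 + 0.6 * w)
        y = 2.0 * x + 3.0 * w + rng.gauss(0, 1)
        rows.append((w, x, y))
    return rows


def fit_toy_scm(rows):
    """Fit Y's mechanism as a per-(w, x) mean plus pooled residual noise."""
    cells = defaultdict(list)
    for w, x, y in rows:
        cells[(w, x)].append(y)
    y_mean = {key: sum(vals) / len(vals) for key, vals in cells.items()}
    y_residuals = [y - y_mean[(w, x)] for w, x, y in rows]
    return {
        "w_values": [w for w, _, _ in rows],  # empirical distribution of the root node
        "y_mean": y_mean,
        "y_residuals": y_residuals,
    }


def do_x(scm, x_value, n_samples):
    """Sample Y after forcing X to `x_value`, leaving W's distribution untouched."""
    samples = []
    for _ in range(n_samples):
        w = rng.choice(scm["w_values"])
        noise = rng.choice(scm["y_residuals"])
        samples.append(scm["y_mean"][(w, x_value)] + noise)
    return samples


data = simulate(5000)
scm = fit_toy_scm(data)
y1 = do_x(scm, 1, 5000)
y0 = do_x(scm, 0, 5000)
estimate = sum(y1) / len(y1) - sum(y0) / len(y0)
print(round(estimate, 2))
```

The intervention bypasses the W -> X mechanism entirely, which is why the difference of the two sample means recovers the simulated effect rather than the confounded association.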
Here we describe how the same four steps (modeling, identification, estimation and refutation) can be useful for the causal prediction task. For learning the predictor, DoWhy will support multiple causal representation learning algorithms. Below we show example code using Invariant Risk Minimization (Arjovsky et al., 2019).
# Create the graph. In addition to structure, this graph includes node and edge
# attributes. Node attributes describe the node (e.g., domain feature); edge
# attributes describe the edge (e.g., monotonic effect).
graph = nx.DiGraph(...)
ctx = dowhy.causal_constraints(graph, target_var="Y")
# Alternatively, add the attributes manually. This will be especially useful for
# image/language datasets where the graph over input variables is infeasible.
ctx = [
    Monotonic_effect(feature_name="age", target="Y"),
    ZeroEffect(feature_name="gender", target="Y"),
    CustomEffect(feature_name="educ", target="Y", effect_fn=fn),
    AttributeInvariance(domain_variable="D", target="Y"),
]
# Identify whether a predictor is possible: P(Y|do(X)), where X denotes all observed parents of Y.
# Returns whether a predictor is possible and, if so, the estimand.
estimand = dowhy.identify_prediction(graph, m, constraints=ctx)
predm = dowhy.predictors.IRM()
predm.fit(data)
ytest = predm.predict(X)  # or ytest = predm.do(X)
dowhy.refute_prediction(ytest, model=predm, graph ...)
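One way to read constraints such as a monotonic or zero effect is as properties that a fitted predictor should satisfy, which a refutation step could then check. The sketch below is an assumption about how such a check might look, not the proposed DoWhy implementation. The dataclasses `MonotonicEffect` and `ZeroEffect` are hypothetical stand-ins for the constraint objects, features are assumed to be numeric, and "monotonic" is read as non-decreasing.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the constraint objects; not DoWhy classes.

@dataclass(frozen=True)
class MonotonicEffect:
    feature_name: str
    target: str


@dataclass(frozen=True)
class ZeroEffect:
    feature_name: str
    target: str


def violates(constraint, predict, row, delta=1.0):
    """Return True if raising the constrained feature by `delta` breaks the constraint."""
    bumped = dict(row)
    bumped[constraint.feature_name] += delta
    change = predict(bumped) - predict(row)
    if isinstance(constraint, ZeroEffect):
        return abs(change) > 1e-9
    if isinstance(constraint, MonotonicEffect):
        return change < 0  # read "monotonic" as non-decreasing here
    raise TypeError(f"unsupported constraint: {constraint!r}")


def find_violations(predict, constraints, rows):
    """List the constraints that `predict` breaks on at least one row."""
    return [c for c in constraints if any(violates(c, predict, r) for r in rows)]


ctx = [MonotonicEffect("age", "Y"), ZeroEffect("gender", "Y")]
rows = [{"age": 30.0, "gender": 0.0}, {"age": 45.0, "gender": 1.0}]


def fair_predict(row):
    return 0.5 * row["age"]


def biased_predict(row):
    return 0.5 * row["age"] + 2.0 * row["gender"]


print(find_violations(fair_predict, ctx, rows))
print(find_violations(biased_predict, ctx, rows))
```

The first predictor passes both checks, while the second is flagged for the zero-effect constraint on gender.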
The constraint-creation steps remain the same. The prediction is now done by the fitted GCM.
# Assume constraints ctx are provided
estimand = dowhy.identify_prediction(graph, m, constraints=ctx)
# (optional) Assign causal models to nodes
dowhy.set_causal_model(causal_graph, "Y", MyConditionalStochasticModel())
# Auto-assign causal models to nodes and fit them to the data
scm = dowhy.fit_scm(causal_graph, data)
# Predict
# do operation produces interventional samples
ytest = np.mean(dowhy.do(scm, estimand, input_values=X))
dowhy.refute_prediction(ytest, model=scm, graph ...)
Our current plans include GCM-based attribution. We are looking for suggestions on non-GCM-based methods for causal attribution.
# Fit an SCM as before
# Analyze root causes from anomalous data.
anomaly_scores = dowhy.anomaly_scores(scm, potentially_anomalous_data)
# Perform interventional analysis on a particular node. Both soft and hard (do) interventions are supported.
intervention_samples = dowhy.intervene(scm, target_node="X", intervention_func=intervention_func)
# Attribute distribution changes
change_attributions = dowhy.distribution_change(scm, data, new_data)
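To give a feel for root-cause analysis with a fitted model, here is a deliberately simple sketch for a two-node chain X -> Y with a linear mechanism. It scores each node by the absolute z-score of its own noise term, which is only one of several possible scoring choices and is not claimed to be how `dowhy.anomaly_scores` would work. All names (`fit_line`, `noise_scores`, `likely_root_cause`) are hypothetical.

```python
import random
import statistics

# Hypothetical sketch of noise-based anomaly scoring; not DoWhy's implementation.
rng = random.Random(2)
x_train = [rng.gauss(0, 1) for _ in range(2000)]
y_train = [2.0 * x + rng.gauss(0, 0.5) for x in x_train]  # chain X -> Y


def fit_line(x, y):
    """Ordinary least squares for y = intercept + slope * x."""
    mean_x, mean_y = statistics.fmean(x), statistics.fmean(y)
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


slope, intercept = fit_line(x_train, y_train)
x_mean, x_sd = statistics.fmean(x_train), statistics.stdev(x_train)
residual_sd = statistics.stdev(
    [y - (intercept + slope * x) for x, y in zip(x_train, y_train)]
)


def noise_scores(x, y):
    """Absolute z-score of each node's own noise term for one observation."""
    return {
        "X": abs(x - x_mean) / x_sd,
        "Y": abs(y - (intercept + slope * x)) / residual_sd,
    }


def likely_root_cause(x, y):
    """Name the node whose noise term is most unusual."""
    scores = noise_scores(x, y)
    return max(scores, key=scores.get)


upstream = likely_root_cause(x=6.0, y=12.0)  # X is extreme, Y just follows it
local = likely_root_cause(x=0.1, y=5.0)      # X is ordinary, Y's mechanism misbehaves
print(upstream, local)
```

Scoring noise terms rather than raw values is what separates the two cases: in the first observation Y is far from its usual range, yet it is fully explained by its parent, so the anomaly is attributed upstream.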
Counterfactuals do require access to a fitted GCM.
# assume causal graph is provided
scm = dowhy.fit_scm(causal_graph, data)
identified_cf = dowhy.identify_counterfactual(scm, evidence, action, outcome)
cf = dowhy.estimate_counterfactual(identified_cf, data)
dowhy.refute_counterfactual(cf, model=scm)
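The reason a fitted GCM is needed is that a counterfactual reuses the unit's own noise terms, which only a structural model can recover. The sketch below walks through the standard abduction, action, prediction steps for a single unit. It assumes a hypothetical linear mechanism with made-up coefficients, and the function names are illustrative rather than part of the proposed API.

```python
# Hypothetical linear mechanism for Y; in practice it would be fitted from data.
COEF_X, COEF_W = 2.0, 3.0


def mechanism_y(x, w, noise):
    """Structural equation Y = 2*X + 3*W + U_Y."""
    return COEF_X * x + COEF_W * w + noise


def counterfactual_y(evidence, action_x):
    """Abduction, action, prediction for a single observed unit."""
    # 1. Abduction: recover the unit's noise term from the evidence.
    noise = evidence["Y"] - mechanism_y(evidence["X"], evidence["W"], 0.0)
    # 2. Action: replace X with the counterfactual value.
    # 3. Prediction: push the same noise through the modified model.
    return mechanism_y(action_x, evidence["W"], noise)


evidence = {"W": 1.0, "X": 0.0, "Y": 3.5}
cf = counterfactual_y(evidence, action_x=1.0)
print(cf)  # 5.5
```

Here the observed unit has noise 0.5, so setting X to 1 while keeping W and the noise fixed yields 2 + 3 + 0.5 = 5.5.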