API Proposal for Causal Prediction (and optionally representation learning)
We would like to enable causally robust prediction using DoWhy's four-step API. Prediction under distribution change is a common use case. While many (non-causal) approaches have been proposed for this problem, causal models are especially useful because they capture the stable relationships in the data-generating process. There is a broad variety of causal approaches to robust prediction, from approaches that rely on learning a full SCM to methods that focus on invariant identification under certain assumptions. Our goal is to provide an easy-to-use causal prediction interface that generalizes across these methods, as well as to encourage cross-fertilization of ideas.
We are also exploring representation learning, since it is essential for unstructured data such as text or images. This proposal contains an experimental section on representation learning and how it relates to causal prediction.
There are two types of users for this API: 1) data scientists who want to apply prediction to their problem; and 2) researchers/developers who would like to contribute new algorithms. For the first group, we'd like to keep the API simple, with sensible defaults wherever possible. For the second, we want the API to be easily extensible to new, future algorithms.
For any prediction problem, we assume that the user will provide some causal constraints/context. These may be derived from domain knowledge (e.g., a monotonic or null effect of an attribute), describe data collection (e.g., a domain attribute in a dataset with multiple sources), provide details on the data-generating process (a confounded, independent, or selected attribute), or give any other context that may be used by future algorithms. Because of the similarity to the Context class from causal discovery, it may be beneficial to use the same class for prediction as well, with two subclasses: DiscoveryContext and PredictionContext.
```python
# The prediction model uses causal constraints (context) that are simply presented as a list.
# These constraints are of two types: 1) the kind of relationship from a feature to the outcome
# (e.g., zero effect or monotonic effect); 2) special kinds of variables that enforce different
# constraints based on the graph structure.
# These constraints can also be thought of as Context (from the causal discovery API) and may just
# inherit from that class. They lead to an internal graph representation that is inferred from
# the constraints, not learnt.
effect_ctx = [
    MonotonicEffect(feature_name="age", target="Y"),
    ZeroEffect(feature_name="gender", target="Y"),
    CustomEffect(feature_name="educ", target="Y", effect_fn=fn),
]
spurious_attribute_ctx = [
    OutcomeConfoundedAttribute(feature_name="color", target="Y", max_distribution_shift=0.7, cause_of_input=True),
    # max shift is unconstrained by default, attributes cause target Y by default
    OutcomeSelectedAttribute(feature_name="location", target="Y"),
    OutcomeIndependentAttribute(feature_name="rotation", target="Y"),
]
ctx = effect_ctx + spurious_attribute_ctx

# All parameters except ctx are optional. If the parents of the target variable are not provided,
# we assume that the parents of Y are unobserved and that the "parents" representation needs to be
# learnt from the input features.
graph = dowhy.build_graph(ctx,
                          parents_target=None,
                          ancestors_target=None,
                          children_target=None,
                          descendants_target=None,
                          parents_children_target=None)  # for the Markov boundary
```
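To make "a graph inferred from the constraints, not learnt" concrete, here is a minimal sketch. The constraint classes are hypothetical stand-ins that mirror the proposed names (they are not an existing DoWhy API), and extra fields such as `max_distribution_shift` and the `CustomEffect` class are omitted for brevity. The mapping from attribute type to edges is an assumption of this sketch: a confounded attribute shares a latent parent with Y, a selected attribute and Y both point into a selection node, an independent attribute affects only the input, and a zero-effect feature gets no edge into Y.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


# Hypothetical constraint classes mirroring the proposal (not a real DoWhy API).
@dataclass
class MonotonicEffect:
    feature_name: str
    target: str


@dataclass
class ZeroEffect:
    feature_name: str
    target: str


@dataclass
class OutcomeConfoundedAttribute:
    feature_name: str
    target: str
    cause_of_input: bool = False


@dataclass
class OutcomeSelectedAttribute:
    feature_name: str
    target: str


@dataclass
class OutcomeIndependentAttribute:
    feature_name: str
    target: str


Constraint = Union[MonotonicEffect, ZeroEffect, OutcomeConfoundedAttribute,
                   OutcomeSelectedAttribute, OutcomeIndependentAttribute]
Edge = Tuple[str, str]


def build_graph_sketch(ctx: List[Constraint], input_node: str = "X") -> Dict[Edge, dict]:
    """Infer an edge list (with edge attributes) from constraints; nothing is learnt."""
    edges: Dict[Edge, dict] = {}
    for c in ctx:
        a, y = c.feature_name, c.target
        if isinstance(c, MonotonicEffect):
            edges[(a, y)] = {"effect": "monotonic"}
        elif isinstance(c, ZeroEffect):
            pass  # zero effect: deliberately no edge a -> y
        elif isinstance(c, OutcomeConfoundedAttribute):
            u = f"U_{a}_{y}"  # latent confounder of the attribute and the outcome
            edges[(u, a)] = {"latent": True}
            edges[(u, y)] = {"latent": True}
            if c.cause_of_input:
                edges[(a, input_node)] = {}
        elif isinstance(c, OutcomeSelectedAttribute):
            s = f"S_{a}_{y}"  # selection node: attribute and outcome both affect inclusion
            edges[(a, s)] = {"selection": True}
            edges[(y, s)] = {"selection": True}
        elif isinstance(c, OutcomeIndependentAttribute):
            edges[(a, input_node)] = {}  # affects the input only; no path to y
    return edges
```

Keeping the constraints as plain data objects means a new algorithm only needs to read the inferred edges, not the original constraint list.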
```python
import networkx as nx

# Another way to initialize the context is to input the graph along with node and edge attributes.
# Node attributes tell us about the node (e.g., an outcome-confounded or outcome-independent feature);
# edge attributes tell us about the edge (e.g., a monotonic effect).
# TODO: to do this well, we need a common DAG format. More generally, we need a common graph format
# so that node attributes and edge attributes can be specified. For now, assume these are networkx attributes.
graph = nx.DiGraph(...)
# this call recovers the constraints relevant for target Y from the graph
ctx = dowhy.causal_constraints(graph, target_var="Y")
```
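The reverse direction, recovering constraints for a target from an attributed graph, could work roughly as follows. To keep the sketch dependency-free, the graph is a plain dictionary of node attributes plus a dictionary of edge attributes, standing in for networkx's per-node and per-edge data dictionaries. The attribute keys (`"role"`, `"effect"`) and the `(kind, feature)` output format are hypothetical conventions, not part of any existing API.

```python
from typing import Dict, List, Tuple


def causal_constraints_sketch(
    node_attrs: Dict[str, dict],
    edge_attrs: Dict[Tuple[str, str], dict],
    target_var: str,
) -> List[Tuple[str, str]]:
    """Return (constraint_kind, feature) pairs relevant to target_var.

    Hypothetical conventions:
      node_attrs[n]["role"] in {"confounded", "selected", "independent"} marks spurious attributes;
      edge_attrs[(u, target_var)]["effect"] (e.g., "monotonic") marks the kind of effect.
    """
    constraints: List[Tuple[str, str]] = []
    # Effect constraints come from edges pointing into the target.
    for (u, v), data in sorted(edge_attrs.items()):
        if v == target_var and "effect" in data:
            constraints.append((data["effect"], u))
    # Attribute constraints come from node roles.
    for node, data in sorted(node_attrs.items()):
        role = data.get("role")
        if role is not None and node != target_var:
            constraints.append((role, node))
    return constraints
```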
Once we have the context, we'd like to check whether a causal predictive model is possible given the available data. As with effect inference, where the conditions for identification vary with the type of estimation method (e.g., conditioning versus instrumental variables), the conditions for identifying a causal predictor depend on the type of causal prediction method. Some examples:
- Regularization-based predictors (e.g., MMD, IRM): the identification criterion is that all nodes appearing in the regularization constraints are observed.
- Causal graph-based predictors: the identification criterion is whether all parents of the target are observed, or (possibly) other criteria based on all ancestors, anti-causal prediction, the Markov boundary, etc.
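The two criteria above reduce to set checks over the graph. Here is a minimal sketch using a plain edge list; the function names are hypothetical, and real identification would also handle the ancestor, anti-causal, and Markov-boundary variants mentioned above.

```python
from typing import Iterable, Set, Tuple

Edge = Tuple[str, str]


def parents(edges: Iterable[Edge], target: str) -> Set[str]:
    """All nodes with an edge into the target."""
    return {u for (u, v) in edges if v == target}


def graph_based_identifiable(edges: Iterable[Edge], target: str,
                             observed: Iterable[str]) -> Tuple[bool, Set[str]]:
    """Graph-based criterion: every parent of the target must be observed."""
    missing = parents(edges, target) - set(observed)
    return not missing, missing


def regularization_identifiable(constraint_nodes: Iterable[str],
                                observed: Iterable[str]) -> Tuple[bool, Set[str]]:
    """Regularization criterion: every node used in a regularization constraint must be observed."""
    missing = set(constraint_nodes) - set(observed)
    return not missing, missing
```

For example, with a latent parent `U` of `Y`, the graph-based predictor is not identified, while a regularizer that only involves an observed attribute and `Y` remains usable.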
```python
# Identify whether a causal predictor P(Y|do(X)) is possible, where X is the set of input features for
# the prediction task. Returns whether such a predictor is possible and, if so, an estimand containing
# the information/formula needed to build the prediction model. E.g., it can provide 1) a list of
# conditional independence constraints to be applied; 2) whether the predictor can be constructed
# using a GCM (are all parents observed?), etc.
estimand = dowhy.identify_predictor(graph=None, context=ctx, input=["X1", "X2"], target="Y")
# Either graph or context must be provided. If context is provided, a graph is built internally.
```
For estimation, we want to support all kinds of causal prediction models. To start with, we will implement IRM [1], MMD, Conditional MMD [2], and Causally Adaptive Constraint Minimization (CACM) [3]. Because CACM adapts the regularization to the context provided, we can make it the default for new users.
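To give a sense of what these regularizers compute, here is a dependency-free sketch of two of the penalties named above: the IRMv1 penalty from [1], specialized here to squared loss with a scalar dummy classifier w = 1.0, and a biased estimate of squared MMD with a Gaussian kernel between one-dimensional representations from two environments. The function names and the bandwidth default are illustrative choices, not part of any DoWhy API; a real implementation would compute these on batched tensors with automatic differentiation.

```python
import math
from typing import List, Sequence, Tuple


def irmv1_penalty(envs: List[Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Sum over environments of (dR_e/dw at w = 1)^2, with squared loss.

    Each env is (f, y): model outputs f_i and labels y_i.
    R_e(w) = mean_i (w * f_i - y_i)^2, so dR_e/dw at w = 1 is mean_i 2 * (f_i - y_i) * f_i.
    """
    total = 0.0
    for f, y in envs:
        grad = sum(2.0 * (fi - yi) * fi for fi, yi in zip(f, y)) / len(f)
        total += grad ** 2
    return total


def gaussian_kernel(a: float, b: float, sigma: float) -> float:
    return math.exp(-((a - b) ** 2) / (2.0 * sigma ** 2))


def mmd2_biased(x: Sequence[float], z: Sequence[float], sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between samples x and z."""
    kxx = sum(gaussian_kernel(a, b, sigma) for a in x for b in x) / (len(x) ** 2)
    kzz = sum(gaussian_kernel(a, b, sigma) for a in z for b in z) / (len(z) ** 2)
    kxz = sum(gaussian_kernel(a, b, sigma) for a in x for b in z) / (len(x) * len(z))
    return kxx + kzz - 2.0 * kxz
```

Both penalties are zero when the corresponding invariance holds (a per-environment optimal scale of 1 for IRM, identical representation distributions for MMD) and grow as it is violated; they are added to the predictive loss with a tunable weight.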
```python
# recipe for new users
predm = dowhy.learn_predictor(estimand, data, method="CACM", method_params={})
ytest = predm.predict(X)  # or predm.do(X)
```
Below is an example using IRM when the user knows their preferred algorithm.
```python
# alternative recipe for advanced users
predm = dowhy.predictors.IRM(estimand, ...)  # ..other params..
predm.fit(data)
ytest = predm.predict(X)  # or predm.do(X)
```
```python
# Alternatively, we can use a graphical causal model (GCM). The steps of modeling (context),
# identification, and estimation remain the same.
estimand = dowhy.identify_predictor(graph=graph, target="Y")  # check if all parents are observed
# Auto-assign causal models to nodes and fit them to the data
scm = dowhy.fit_scm(graph, estimand, data)  # estimand is useful to know which functions to fit
# Predict: the do operation produces interventional samples
ytest = scm.predict(X, target="Y")  # or ytest = dowhy.do(scm, X, target="Y")
```
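When all parents of Y are observed, predicting Y amounts to fitting the mechanism of node Y on its parents and evaluating it at new parent values. The sketch below assumes a linear mechanism fitted by ordinary least squares (normal equations solved by Gaussian elimination) using only the standard library. The linear form and the function names are assumptions for illustration; DoWhy's existing `dowhy.gcm` module assigns and fits mechanisms through its own API.

```python
from typing import Dict, List, Sequence


def _solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [bi] for row, bi in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def fit_linear_mechanism(data: Dict[str, Sequence[float]], parents: List[str],
                         target: str) -> List[float]:
    """OLS fit of target ~ intercept + parents; returns [intercept, coef_1, ...]."""
    rows = [[1.0] + [data[p][i] for p in parents] for i in range(len(data[target]))]
    k = len(rows[0])
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(k)] for i in range(k)]
    xty = [sum(r[i] * yv for r, yv in zip(rows, data[target])) for i in range(k)]
    return _solve(xtx, xty)


def predict_mechanism(coefs: List[float], parents: List[str], x: Dict[str, float]) -> float:
    """Evaluate the fitted mechanism at parent values x (e.g., values set by do(X=x))."""
    return coefs[0] + sum(c * x[p] for c, p in zip(coefs[1:], parents))
```

Because only the mechanism of Y is used, the prediction does not depend on how the parents themselves are distributed, which is what makes it stable when that distribution shifts.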
```python
# Here we can cross-validate over a new distribution if one is available, simulate such
# distributions, check whether the predictor satisfies the constraints from the estimand,
# or create counterfactual examples to test the model.
# Cross-validation using sklearn. Note that cross_val_score fits a fresh clone of predm on each fold;
# to score the already-fitted predm on new data, use sklearn.metrics.accuracy_score(ynew, predm.predict(Xnew)).
from sklearn.model_selection import cross_val_score
cross_val_score(predm, Xnew, ynew, cv=5, scoring="accuracy")
# If OOD data is not available, generate OOD data first and then cross-validate.
Xnew, ynew = dowhy.generate()
cross_val_score(predm, Xnew, ynew, cv=5, scoring="accuracy")
```
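When no OOD data is available, one simple way to simulate a shifted distribution (a stand-in for the proposed `generate()` call, whose design is still open) is to resample the data so that a spurious attribute agrees with the label at a different rate than in training, and then score the already-fitted model on that sample. The function names, the assumption of a discrete attribute and label, and the subsampling scheme are all illustrative assumptions.

```python
import random
from typing import Dict, List, Sequence


def shifted_sample(rows: List[Dict], attr: str, target: str,
                   agree_fraction: float, n: int, seed: int = 0) -> List[Dict]:
    """Subsample n rows so that a fraction agree_fraction has row[attr] == row[target].

    Flipping agree_fraction relative to training (e.g., 0.9 -> 0.1) breaks a spurious correlation.
    Sampling is without replacement, so each pool must be large enough.
    """
    rng = random.Random(seed)
    agree = [r for r in rows if r[attr] == r[target]]
    disagree = [r for r in rows if r[attr] != r[target]]
    n_agree = round(agree_fraction * n)
    if n_agree > len(agree) or n - n_agree > len(disagree):
        raise ValueError("not enough rows to build the requested shift")
    sample = rng.sample(agree, n_agree) + rng.sample(disagree, n - n_agree)
    rng.shuffle(sample)
    return sample


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    return sum(int(a == b) for a, b in zip(y_true, y_pred)) / len(y_true)
```

A model that relies on the attribute as a shortcut scores only as well as the agreement rate on such a sample, while a model using stable features should be largely unaffected.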
```python
# finally, we can use refutations that check estimand constraints, counterfactual properties, etc.
dowhy.refute_prediction(estimand, predm, data, ...)
```
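As one example of a refutation that checks an estimand constraint, suppose the estimand requires predictions to be independent of an attribute A conditional on Y. A crude check compares mean predictions across values of A within each stratum of Y; a large gap suggests the constraint is violated. The function name and the use of a mean-difference statistic (rather than a proper conditional independence test) are assumptions of this sketch.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def conditional_mean_gap(preds: Sequence[float], attr: Sequence[Hashable],
                         y: Sequence[Hashable]) -> Dict[Hashable, float]:
    """For each label value, return max - min of the mean prediction across attribute groups."""
    groups = defaultdict(lambda: defaultdict(list))  # y -> attr -> list of predictions
    for p, a, yv in zip(preds, attr, y):
        groups[yv][a].append(p)
    gaps = {}
    for yv, by_attr in groups.items():
        means = [sum(ps) / len(ps) for ps in by_attr.values()]
        gaps[yv] = max(means) - min(means)
    return gaps
```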
For comparing different algorithms, here is a simple example to enable easy iteration.
```python
# ctx, graph construction and estimand remain the same.
methods = [IRM(), CACM()]
for m in methods:
    predm = dowhy.learn_predictor(estimand, data, method=m, method_params={...})
    ytest = predm.predict(X)  # or predm.do(X)
```
Data for prediction can be of two types: 1) tabular data, where features are provided; and 2) unstructured data such as text, where a representation over the features needs to be learnt. In the prediction API, this is a low-level detail: the algorithms are free to learn representations, but the user only cares about the final prediction. In some cases, however, it may be useful to access the representation and use it for other tasks. For example, a representation may be useful for causal discovery and causal inference over text or image data. When inferring the effect of a treatment on Y, one may learn a representation from textual data W that predicts Y (while satisfying some constraints to avoid the impact of the treatment), and then use the representation as a common cause of the treatment and Y. If a prediction model has already been learnt, we can access the representation as follows:
```python
# access representation from fitted model
predm.causal_repr(...)  # …optional_params…
```
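The effect-inference use described above can be sketched with a discrete representation. Once each unit has a representation value z (for instance, a cluster id derived from a learnt text representation), treating z as a common cause gives the backdoor-adjusted estimate: the sum over z of P(z) * (E[Y | T=1, z] - E[Y | T=0, z]). The discretization and this stratification estimator are assumptions for illustration; with a continuous representation one would use a regression or weighting estimator instead.

```python
from collections import defaultdict
from typing import Hashable, Sequence


def backdoor_ate(t: Sequence[int], y: Sequence[float], z: Sequence[Hashable]) -> float:
    """Stratified ATE: sum over z of P(z) * (E[Y | T=1, z] - E[Y | T=0, z]).

    Treatment values must be 0/1, and every stratum needs both treated and control units.
    """
    strata = defaultdict(lambda: {0: [], 1: []})
    for ti, yi, zi in zip(t, y, z):
        strata[zi][ti].append(yi)
    n = len(t)
    ate = 0.0
    for zi, arms in strata.items():
        if not arms[0] or not arms[1]:
            raise ValueError(f"positivity violated in stratum {zi!r}")
        diff = sum(arms[1]) / len(arms[1]) - sum(arms[0]) / len(arms[0])
        ate += (len(arms[0]) + len(arms[1])) / n * diff
    return ate
```

For instance, if z both raises the baseline outcome and makes treatment more likely, a naive treated-versus-control difference overstates the effect, while the stratified estimate recovers the within-stratum difference.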
If the goal is only to learn a representation, without using a predictive loss, we can instead call:
```python
# context ctx is available
dowhy.learn_repr(ctx, data, method="contrastive", ...)
```
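The proposal leaves the contrastive objective open. One common choice is the InfoNCE loss, in which each anchor embedding should be more similar to its own positive (e.g., an augmented view of the same example) than to the positives of the other examples in the batch. The dependency-free sketch below computes it with cosine similarity and a temperature; the function names and the temperature default are illustrative, and whether `learn_repr` would use this particular loss is an assumption.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def info_nce(anchors: List[Sequence[float]], positives: List[Sequence[float]],
             temperature: float = 0.1) -> float:
    """Mean over i of -log softmax_j(sim(anchor_i, positive_j) / T), evaluated at j = i."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [cosine(a, p) / temperature for p in positives]
        m = max(logits)  # subtract the max for numerical stability
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denom - logits[i]
    return total / len(anchors)
```

The loss is small when each anchor is aligned with its own positive and large when it is closer to another example's positive.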
Q. Does it belong in DoWhy or in a separate package in py-why?
- A. We can take an approach similar to causal discovery. We aim to add the basic API and example algorithms within DoWhy. If there is more interest and people want to add more complex algorithms (and there’s enough support), it may make sense to start a new repository.
- [1] Arjovsky, Martin, et al. "Invariant Risk Minimization." arXiv preprint arXiv:1907.02893 (2019).
- [2] Makar, Maggie, et al. "Causally Motivated Shortcut Removal Using Auxiliary Labels." AISTATS (2022).
- [3] Kaur, Jivat Neet, Emre Kiciman, and Amit Sharma. "Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization." arXiv preprint arXiv:2206.07837 (2022).