diff --git a/notebooks/CausalityDataset setup.ipynb b/notebooks/CausalityDataset setup.ipynb new file mode 100644 index 00000000..e328bf46 --- /dev/null +++ b/notebooks/CausalityDataset setup.ipynb @@ -0,0 +1,1863 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f3a2f126", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Setting up the data and causal model: CausalityDataset\n", + "\n", + "This notebook demonstrates how to use and configure `CausalityDataset` using an arbitrary `pd.DataFrame`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "d43137b0", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os, sys\n", + "import warnings\n", + "warnings.filterwarnings('ignore') # suppress sklearn deprecation warnings for now..\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# the below checks for whether we run dowhy, causaltune, and FLAML from source\n", + "root_path = root_path = os.path.realpath('../..')\n", + "try:\n", + " import causaltune\n", + "except ModuleNotFoundError:\n", + " sys.path.append(os.path.join(root_path, \"causaltune\"))\n", + "\n", + "try:\n", + " import dowhy\n", + "except ModuleNotFoundError:\n", + " sys.path.append(os.path.join(root_path, \"dowhy\"))\n", + "\n", + "try:\n", + " import flaml\n", + "except ModuleNotFoundError:\n", + " sys.path.append(os.path.join(root_path, \"FLAML\"))\n", + " \n", + " \n", + " \n", + "from causaltune import CausalTune\n", + "from causaltune.datasets import synth_ihdp, iv_dgp_econml, generate_non_random_dataset\n", + "from causaltune.data_utils import CausalityDataset\n", + "from causaltune.dataset_processor import CausalityDatasetProcessor" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e072c202", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# this makes the notebook expand to full width of the browser window\n", + "from IPython.core.display import display, HTML\n", + "display(HTML(\"\"))" + ] + }, + { + "cell_type": "markdown", + "id": "c2a0429f", + "metadata": {}, + "source": [ + "### Random assignment \n", + "We first illustrate the model setup with a subset of data from the Infant Health and Development Program (IHDP)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0efc918c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3
015.599916-0.528603-0.3434551.128554
106.875856-1.736945-1.8020020.383828
202.996273-0.807451-0.202946-0.360898
301.3662060.3900830.596582-1.850350
401.963538-1.045229-0.6027100.011465
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554\n", + "1 0 6.875856 -1.736945 -1.802002 0.383828\n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898\n", + "3 0 1.366206 0.390083 0.596582 -1.850350\n", + "4 0 1.963538 -1.045229 -0.602710 0.011465" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df = synth_ihdp(return_df=True).iloc[:,:5]\n", + "display(df.head())" + ] + }, + { + "cell_type": "markdown", + "id": "c5bce66b", + "metadata": {}, + "source": [ + "Generally, at least three arguments have to be supplied to `CausalityDataset`:\n", + "- `data`: input dataframe\n", + "- `treatment`: name of treatment column\n", + "- `outcomes`: list of names of outcome columns; provide as list even if there's just one outcome of interest\n", + "\n", + "In addition, if the propensities to treat are known, then provide the corresponding column name(s) via `propensity_modifiers`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bb50909e", + "metadata": {}, + "outputs": [], + "source": [ + "cd = CausalityDataset(data=df, treatment='treatment', outcomes=['y_factual'])" + ] + }, + { + "cell_type": "markdown", + "id": "73b6395a", + "metadata": {}, + "source": [ + "The next step is to use `cd.preprocess_dataset()` to deal with missing values, remove outliers etc." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8803d695", + "metadata": {}, + "outputs": [], + "source": [ + "cd.preprocess_dataset()" + ] + }, + { + "cell_type": "markdown", + "id": "dafa93e0", + "metadata": {}, + "source": [ + "The causal model is built by assuming that all remaining features are `effect_modifiers`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6695f65f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['x1', 'x2', 'x3']\n" + ] + } + ], + "source": [ + "print(cd.effect_modifiers)" + ] + }, + { + "cell_type": "markdown", + "id": "50447729", + "metadata": {}, + "source": [ + "Subsequently, use the preprocessed `CausalityDataset` object for training as follow: `CausalTune.fit(cd, outcome='y_factual')`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "eb9ebea5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting a Propensity-Weighted scoring estimator to be used in scoring tasks\n", + "Propensity Model Fitted Successfully\n" + ] + } + ], + "source": [ + "ct = CausalTune(components_time_budget=5,) \n", + "ct.fit(data=cd, outcome='y_factual')" + ] + }, + { + "cell_type": "markdown", + "id": "e8cf75fb", + "metadata": {}, + "source": [ + "The causal graph that CausalTune uses is " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6b9a1ad6", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "ct.causal_model.view_model()" + ] + }, + { + "cell_type": "markdown", + "id": "f0ec03d0", + "metadata": {}, + "source": [ + "*Note that the variable `random` can be ignored and has no real meaning for the causal model.*" + ] + }, + { + "cell_type": "markdown", + "id": "80318c33", + "metadata": {}, + "source": [ + "#### Adding common causes\n", + "\n", + "If we had reason to assume that for instance `x1` and `x2` are `common causes` instead of `effect modifiers`, this can be made explicit:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6babd054", + "metadata": {}, + "outputs": [], + "source": [ + "cd = CausalityDataset(data=df, treatment='treatment', outcomes=['y_factual'], common_causes=['x1', 'x2'])" + ] + }, + { + "cell_type": "markdown", + "id": "256f2054", + "metadata": {}, + "source": [ + "The causal graph becomes" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "510157f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting a Propensity-Weighted scoring estimator to be used in scoring tasks\n", + "Propensity Model Fitted Successfully\n" + ] + } + ], + "source": [ + "cd.preprocess_dataset()\n", + "ct = CausalTune(components_time_budget=5,) \n", + "ct.fit(data=cd, outcome='y_factual')\n", + "ct.causal_model.view_model()" + ] + }, + { + "cell_type": "markdown", + "id": "ca35fcef", + "metadata": {}, + "source": [ + "For how to proceed further with CausalTune, see for instance [here](https://github.com/py-why/causaltune/blob/main/notebooks/Random%20assignment%2C%20binary%20CATE%20example.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "c1be7581", + "metadata": {}, + "source": [ + "### Instrumental variable identification\n", + "\n", + "In other problems of causal inference, one may seek to follow an instrumental variable approach ([Example notebook](https://github.com/py-why/causaltune/blob/main/notebooks/Comparing%20IV%20Estimators.ipynb)). " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2a35636e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " x1 x2 x3 x4 y treatment Z\n", + "0 -0.662658 1.124321 -1.699940 -0.379268 5.236122 0 0\n", + "1 -0.788565 1.336684 -0.539586 -0.785838 12.039615 1 1\n", + "2 -0.344655 -0.204201 -1.267158 0.898114 23.469351 1 1\n", + "3 0.125284 -0.557028 0.403744 0.579168 5.300115 0 0\n", + "4 0.356507 0.330607 0.430286 1.201554 12.855370 0 0\n" + ] + } + ], + "source": [ + "#load data\n", + "df = iv_dgp_econml(p=4).data\n", + "del df['random']\n", + "print(df.head(5))" + ] + }, + { + "cell_type": "markdown", + "id": "a012cdff", + "metadata": {}, + "source": [ + "Suppose we want to use $Z$ as an instrument." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c9be746a", + "metadata": {}, + "outputs": [], + "source": [ + "cd = CausalityDataset(\n", + " data=df, \n", + " treatment='treatment',\n", + " outcomes=['y'],\n", + " instruments=['Z']\n", + " )\n", + "cd.preprocess_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0bfd06a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outcomes: ['y']\n", + "Treatment: treatment\n", + "Instruments: ['Z']\n", + "Effect modifiers: ['x1', 'x2', 'x3', 'x4']\n" + ] + } + ], + "source": [ + "print('Outcomes:', cd.outcomes)\n", + "print('Treatment:', cd.treatment)\n", + "print('Instruments:', cd.instruments)\n", + "print('Effect modifiers:', cd.effect_modifiers)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0e738f3e", + "metadata": {}, + "outputs": [], + "source": [ + "ct = CausalTune(\n", + " components_time_budget=5,\n", + " estimator_list=['iv.econml.iv.dml.DMLIV']\n", + " ) \n", + "ct.fit(data=cd, outcome='y')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "83f847f9", + "metadata": {}, + "outputs": [], + "source": [ + "ct.causal_model.view_model()" + ] + }, + { + "cell_type": "markdown", + "id": "ecb28b61", + "metadata": {}, + "source": [ + "### Propensity modifiers\n", + "\n", + "If there are well-known propensity modifiers, it is also possible to make those explicit. This can, e.g., be used to pass them directly into the model instead of fitting a propensity weight model (for more details, see [here](https://github.com/py-why/causaltune/blob/main/notebooks/Propensity%20Model%20Selection.ipynb))." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b1407bbb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " T Y X1 X2 X3 X4 X5 propensity\n", + "0 0 1.650705 0.521524 -1.393497 0.010672 -0.828778 1.019257 0.245100\n", + "1 0 -0.888552 -0.782541 -1.384920 -0.233656 0.150249 -0.495169 0.205945\n", + "2 0 -0.516344 -0.154831 -0.098985 2.335176 -1.888928 -0.594854 0.235870\n", + "3 1 0.601679 0.109516 0.092910 0.525252 -1.172202 -0.177947 0.439021\n", + "4 0 0.569122 -0.365630 -0.343061 -0.420554 -0.995160 1.548502 0.335151\n" + ] + } + ], + "source": [ + "#load data\n", + "df = generate_non_random_dataset().data\n", + "del df['random']\n", + "print(df.head(5))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "1b906467", + "metadata": {}, + "outputs": [], + "source": [ + "cd = CausalityDataset(\n", + " data=df, \n", + " treatment='T',\n", + " outcomes=['Y'],\n", + " propensity_modifiers=['propensity']\n", + " )\n", + "cd.preprocess_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "71394906", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outcomes: ['Y']\n", + "Treatment: T\n", + "Propensity Modifiers: ['propensity']\n", + "Effect modifiers: ['X1', 'X2', 'X3', 'X4', 'X5']\n" + ] + } + ], + "source": [ + "print('Outcomes:', cd.outcomes)\n", + "print('Treatment:', cd.treatment)\n", + "print('Propensity Modifiers:', cd.propensity_modifiers)\n", + "print('Effect modifiers:', cd.effect_modifiers)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "359fd218", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting a Propensity-Weighted scoring estimator to be used in scoring tasks\n", + "Propensity Model Fitted Successfully\n" + ] + } + ], + "source": [ + "ct = CausalTune(\n", + " components_time_budget=5,\n", + ") \n", + "ct.fit(data=cd, outcome='Y')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "08e0ee9c", + "metadata": {}, + "outputs": [], + "source": [ + "ct.causal_model.view_model()" + ] + }, + { + "cell_type": "markdown", + "id": "818762bf-a3e7-426b-87e7-3cbcaa5d1ef8", + "metadata": {}, + "source": [ + "### Pre-processing of the test dataset based on the training set\n", + "You can also preprocess the data in the CausalityDataset using one of the popular category encoders: OneHot, WoE, Label, Target." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fd26bb39-e55f-4f76-b225-838ddb16675b", + "metadata": {}, + "outputs": [], + "source": [ + "unique_values_1 = ['A', 'B', 'C', 'D', 'E']\n", + "unique_values_2 = ['F', 'G', 'H', 'I', 'J', 'K']\n", + "unique_values_3 = ['L', 'M', 'N', 'O', 'P', 'Q', 'R']" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9354687b-6d4a-448a-813d-c8f21c761b8c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3
015.599916-0.528603-0.3434551.128554
106.875856-1.736945-1.8020020.383828
202.996273-0.807451-0.202946-0.360898
301.3662060.3900830.596582-1.850350
401.963538-1.045229-0.6027100.011465
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554\n", + "1 0 6.875856 -1.736945 -1.802002 0.383828\n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898\n", + "3 0 1.366206 0.390083 0.596582 -1.850350\n", + "4 0 1.963538 -1.045229 -0.602710 0.011465" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_train = synth_ihdp(return_df=True).iloc[:,:5]\n", + "display(df_train.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f22414f8-0624-4e04-9a3d-4dfe31e4f8f0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3
015.599916-0.528603-0.3434551.128554
106.875856-1.736945-1.8020020.383828
202.996273-0.807451-0.202946-0.360898
301.3662060.3900830.596582-1.850350
401.963538-1.045229-0.6027100.011465
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554\n", + "1 0 6.875856 -1.736945 -1.802002 0.383828\n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898\n", + "3 0 1.366206 0.390083 0.596582 -1.850350\n", + "4 0 1.963538 -1.045229 -0.602710 0.011465" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_test = synth_ihdp(return_df=True).iloc[:,:5]\n", + "display(df_test.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "70477eb1-9a14-4927-85ef-fb6888d432c7", + "metadata": {}, + "outputs": [], + "source": [ + "# Adding the category columns with random values\n", + "df_train['category_col1'] = np.random.choice(unique_values_1, len(df_train))\n", + "df_train['category_col2'] = np.random.choice(unique_values_2, len(df_train))\n", + "df_train['category_col3'] = np.random.choice(unique_values_3, len(df_train))\n", + "\n", + "df_test['category_col1'] = np.random.choice(unique_values_1, len(df_test))\n", + "df_test['category_col2'] = np.random.choice(unique_values_2, len(df_test))\n", + "df_test['category_col3'] = np.random.choice(unique_values_3, len(df_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ba538a59-a875-4766-a41a-f9099f3add16", + "metadata": {}, + "outputs": [], + "source": [ + "cd_train = CausalityDataset(\n", + " data=df_train,\n", + " treatment='treatment',\n", + " outcomes=['y_factual'],\n", + " effect_modifiers=['x1', 'x2', 'x3', 'category_col1', 'category_col2', 'category_col3']\n", + ")\n", + "\n", + "cd_test = CausalityDataset(\n", + " data=df_test,\n", + " treatment='treatment',\n", + " outcomes=['y_factual'],\n", + " effect_modifiers=['x1', 'x2', 'x3', 'category_col1', 'category_col2', 'category_col3']\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "71fb2260-ca6e-4a7c-b220-8de92af82917", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3category_col1category_col2category_col3random
015.599916-0.528603-0.3434551.128554EKR1
106.875856-1.736945-1.8020020.383828AFM1
202.996273-0.807451-0.202946-0.360898DHO0
301.3662060.3900830.596582-1.850350DKR0
401.963538-1.045229-0.6027100.011465CKQ0
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 E \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 A \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 D \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 D \n", + "4 0 1.963538 -1.045229 -0.602710 0.011465 C \n", + "\n", + " category_col2 category_col3 random \n", + "0 K R 1 \n", + "1 F M 1 \n", + "2 H O 0 \n", + "3 K R 0 \n", + "4 K Q 0 " + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cd_train.data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "78d7813b-cc59-4b49-a92a-fb41bae4bd9d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3category_col1category_col2category_col3random
015.599916-0.528603-0.3434551.128554BHM1
106.875856-1.736945-1.8020020.383828CIM0
202.996273-0.807451-0.202946-0.360898BKR0
301.3662060.3900830.596582-1.850350CHP1
401.963538-1.045229-0.6027100.011465AHO1
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 B \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 C \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 B \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 C \n", + "4 0 1.963538 -1.045229 -0.602710 0.011465 A \n", + "\n", + " category_col2 category_col3 random \n", + "0 H M 1 \n", + "1 I M 0 \n", + "2 K R 0 \n", + "3 H P 1 \n", + "4 H O 1 " + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cd_test.data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "2c4767e4-7aa2-47f5-ad9d-d830e77d78b0", + "metadata": {}, + "source": [ + "You can select one of the categorical encoders: `\"onehot\", \"label\", \"target\", \"woe\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "7c0b8f72-6efb-4812-80f7-164557e9eea6", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_processor = CausalityDatasetProcessor()\n", + "dataset_processor.fit(\n", + " cd=cd_train,\n", + " encoder_type=\"label\"\n", + ")\n", + "cd_train = dataset_processor.transform(cd_train)\n", + "cd_test = dataset_processor.transform(cd_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "7b9ee8de-ecfd-475c-9369-4bc9abd81454", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3randomcategory_col1category_col2category_col3
015.599916-0.528603-0.3434551.1285541.0111
106.875856-1.736945-1.8020020.3838281.0222
202.996273-0.807451-0.202946-0.3608980.0333
301.3662060.3900830.596582-1.8503500.0311
401.963538-1.045228-0.6027100.0114650.0414
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 random category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 1.0 1 \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 1.0 2 \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 0.0 3 \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 0.0 3 \n", + "4 0 1.963538 -1.045228 -0.602710 0.011465 0.0 4 \n", + "\n", + " category_col2 category_col3 \n", + "0 1 1 \n", + "1 2 2 \n", + "2 3 3 \n", + "3 1 1 \n", + "4 1 4 " + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cd_train.data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "82ba4136-d406-4b3d-9b13-4bc765ab925d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3randomcategory_col1category_col2category_col3
015.599916-0.528603-0.3434551.1285541.0532
106.875856-1.736945-1.8020020.3838280.0442
202.996273-0.807451-0.202946-0.3608980.0511
301.3662060.3900830.596582-1.8503501.0436
401.963538-1.045228-0.6027100.0114651.0233
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 random category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 1.0 5 \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 0.0 4 \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 0.0 5 \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 1.0 4 \n", + "4 0 1.963538 -1.045228 -0.602710 0.011465 1.0 2 \n", + "\n", + " category_col2 category_col3 \n", + "0 3 2 \n", + "1 4 2 \n", + "2 1 1 \n", + "3 3 6 \n", + "4 3 3 " + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cd_test.data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "4ef24918-f676-4da8-b77a-ebf1347c7be9", + "metadata": {}, + "source": [ + "### Example of model training on transformed data\n", + "Now if `outcome_model=\"auto\"` in the CausalTune constructor, we search over a simultaneous search space for the EconML estimators and for FLAML wrappers for common regressors. The old behavior is now achieved by `outcome_model=\"nested\"` (Refitting AutoML for each estimator)." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "7016148c-d48e-4d1d-a951-160b29b6a37b", + "metadata": {}, + "outputs": [], + "source": [ + "# training configs\n", + "\n", + "# set evaluation metric\n", + "metric = \"energy_distance\"\n", + "\n", + "# it's best to specify either time_budget or components_time_budget, \n", + "# and let the other one be inferred; time in seconds\n", + "components_time_budget = 10\n", + "\n", + "# specify training set size\n", + "train_size = 0.7" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "3f6b2cfd-a26e-4c96-8504-7380459d1a3d", + "metadata": {}, + "outputs": [], + "source": [ + "ct = CausalTune(\n", + " estimator_list=[\n", + " \"DomainAdaptationLearner\",\n", + " \"CausalForestDML\",\n", + " \"ForestDRLearner\",\n", + " ],\n", + " metric=metric,\n", + " verbose=1,\n", + " components_time_budget=components_time_budget,\n", + " train_size=train_size,\n", + " outcome_model=\"auto\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "cbdc9c92-33c8-41c4-ab15-8e15703b5a56", + "metadata": {}, + "outputs": [], + "source": [ + "# run causaltune\n", + "ct.fit(data=cd_train, outcome=cd_train.outcomes[0])\n", + "\n", + "print('---------------------')\n", + "# return best estimator\n", + "print(f\"Best estimator: {ct.best_estimator}\")\n", + "# config of best estimator:\n", + "print(f\"Best config: {ct.best_config}\")\n", + "# best score:\n", + "print(f\"Best score: {ct.best_score}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "56d79fe7-bed7-4ccb-96bd-5b0843b149fb", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = ct.predict(cd_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "27bab1dd-abd2-41a4-b30c-62b4722d0872", + "metadata": {}, + "outputs": [], + "source": [ + "predictions" + ] + }, + { + "cell_type": "markdown", + "id": "ef1b4809-cc89-4318-af18-671ba2c70dd5", + "metadata": {}, + "source": [ + "### Using pre-processing in the model object\n", + "- You can also use `preprocess = True` in the `CausalTune` fit method to do preprocessing automatically\n", + "- You should specify `encoder_type`\n", + "- You should also specify `encoder_outcome` (binary target column) for the `\"woe\", \"target\"` encoders, no need for `\"onehot\", \"label\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "7d251250-c64a-43b5-b804-1e3e672acb38", + "metadata": {}, + "outputs": [], + "source": [ + "unique_values_1 = ['A', 'B', 'C', 'D', 'E']\n", + "unique_values_2 = ['F', 'G', 'H', 'I', 'J', 'K']\n", + "unique_values_3 = ['L', 'M', 'N', 'O', 'P', 'Q', 'R']\n", + "\n", + "df_train = synth_ihdp(return_df=True).iloc[:,:5]\n", + "df_test = synth_ihdp(return_df=True).iloc[:,:5]\n", + "\n", + "df_train['category_col1'] = np.random.choice(unique_values_1, len(df_train))\n", + "df_train['category_col2'] = np.random.choice(unique_values_2, len(df_train))\n", + "df_train['category_col3'] = np.random.choice(unique_values_3, len(df_train))\n", + "\n", + "df_test['category_col1'] = np.random.choice(unique_values_1, len(df_test))\n", + "df_test['category_col2'] = np.random.choice(unique_values_2, len(df_test))\n", + "df_test['category_col3'] = np.random.choice(unique_values_3, len(df_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "fef7b317-08aa-45bb-8eb1-f8c8eacae428", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3category_col1category_col2category_col3
015.599916-0.528603-0.3434551.128554AJN
106.875856-1.736945-1.8020020.383828BJP
202.996273-0.807451-0.202946-0.360898AJP
301.3662060.3900830.596582-1.850350EFM
401.963538-1.045229-0.6027100.011465DGQ
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 A \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 B \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 A \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 E \n", + "4 0 1.963538 -1.045229 -0.602710 0.011465 D \n", + "\n", + " category_col2 category_col3 \n", + "0 J N \n", + "1 J P \n", + "2 J P \n", + "3 F M \n", + "4 G Q " + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_train.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "372e47a5-1da8-4273-80d1-ad8392720b4d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
treatmenty_factualx1x2x3category_col1category_col2category_col3
015.599916-0.528603-0.3434551.128554CKN
106.875856-1.736945-1.8020020.383828DIN
202.996273-0.807451-0.202946-0.360898AGP
301.3662060.3900830.596582-1.850350DJM
401.963538-1.045229-0.6027100.011465EHP
\n", + "
" + ], + "text/plain": [ + " treatment y_factual x1 x2 x3 category_col1 \\\n", + "0 1 5.599916 -0.528603 -0.343455 1.128554 C \n", + "1 0 6.875856 -1.736945 -1.802002 0.383828 D \n", + "2 0 2.996273 -0.807451 -0.202946 -0.360898 A \n", + "3 0 1.366206 0.390083 0.596582 -1.850350 D \n", + "4 0 1.963538 -1.045229 -0.602710 0.011465 E \n", + "\n", + " category_col2 category_col3 \n", + "0 K N \n", + "1 I N \n", + "2 G P \n", + "3 J M \n", + "4 H P " + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_test.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "d0f3ae9f-1f5d-44b8-8c50-953a4ce1e7ef", + "metadata": {}, + "outputs": [], + "source": [ + "cd_train = CausalityDataset(\n", + " data=df_train,\n", + " treatment='treatment',\n", + " outcomes=['y_factual'],\n", + " effect_modifiers=['x1', 'x2', 'x3', 'category_col1', 'category_col2', 'category_col3']\n", + ")\n", + "\n", + "cd_test = CausalityDataset(\n", + " data=df_test,\n", + " treatment='treatment',\n", + " outcomes=['y_factual'],\n", + " effect_modifiers=['x1', 'x2', 'x3', 'category_col1', 'category_col2', 'category_col3']\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "6c1da720-1ef6-4bba-8bb5-047c9bff01d4", + "metadata": {}, + "outputs": [], + "source": [ + "ct = CausalTune(\n", + " estimator_list=[\n", + " \"DomainAdaptationLearner\",\n", + " \"CausalForestDML\",\n", + " \"ForestDRLearner\",\n", + " ],\n", + " metric=metric,\n", + " verbose=1,\n", + " components_time_budget=components_time_budget,\n", + " train_size=train_size,\n", + " outcome_model=\"auto\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "82699ee7-aafb-45c9-994f-9a0296dfe30a", + "metadata": {}, + "outputs": [], + "source": [ + "# run causaltune\n", + "ct.fit(data=cd_train, outcome=cd_train.outcomes[0], preprocess=True, encoder_type = \"label\")\n", + "\n", + "print('---------------------')\n", + "# return best estimator\n", + "print(f\"Best estimator: {ct.best_estimator}\")\n", + "# config of best estimator:\n", + "print(f\"Best config: {ct.best_config}\")\n", + "# best score:\n", + "print(f\"Best score: {ct.best_score}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "92cbd4ad-ab7d-45dd-a13c-c7c99dacd63b", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = ct.predict(cd_train, preprocess=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "d16777d6-f468-469d-813f-da1f90455d37", + "metadata": {}, + "outputs": [], + "source": [ + "predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffaa475c-2cd6-46c8-ab23-f64f0f5f1506", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}