diff --git a/examples/lime/LIME.ipynb b/examples/lime/LIME.ipynb
index 8495304..f1a97d2 100644
--- a/examples/lime/LIME.ipynb
+++ b/examples/lime/LIME.ipynb
@@ -4,23 +4,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Tutorial to invoke LIME explainers via AIX360\n",
+ "# Tutorial to invoke LIME explainers via aix360\n",
"\n",
- "There are two ways to use LIME for explanations after installing aix360:\n",
- "- Since LIME is installed along with other libraries in aix360, it can simply be invoked directly.\n",
- "- LIME can also be invoked in a manner similar to other algorithms in aix360 via the implemented wrapper classes.\n",
+ "There are two ways to use [LIME](https://github.com/marcotcr/lime) explainers after installing aix360:\n",
+ "- [Approach 1 (aix360 style)](#approach1): LIME explainers can be invoked in a manner similar to other explainer algorithms in aix360 via the implemented wrapper classes.\n",
+ "- [Approach 2 (original style)](#approach2): Since LIME comes pre-installed in aix360, the explainers can simply be invoked directly.\n",
"\n",
- "This notebook shows both these approaches to invoke LIME in aix360. The notebook is based on the following example from the original LIME tutorial: \n",
- "https://github.com/marcotcr/lime/blob/master/doc/notebooks/Lime%20-%20multiclass.ipynb\n"
+ "This notebook showcases both these approaches to invoke LIME. The notebook is based on the following example from the original LIME tutorial: https://marcotcr.github.io/lime/tutorials/Lime%20-%20multiclass.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Approach 1 (original style)\n",
- "### LIME Text Example \n",
- "- Example from https://marcotcr.github.io/lime/tutorials/Lime%20-%20multiclass.html"
+ "## Approach 1 (aix360 style)\n",
+ "\n",
+ "- Note the import statement related to LimeTextExplainer"
]
},
{
@@ -36,7 +35,8 @@
"import sklearn.ensemble\n",
"import sklearn.metrics\n",
"\n",
- "from lime.lime_text import LimeTextExplainer"
+ "# Importing LimeTextExplainer (aix360 sytle)\n",
+ "from aix360.algorithms.lime import LimeTextExplainer"
]
},
{
@@ -44,6 +44,24 @@
"execution_count": 2,
"metadata": {},
"outputs": [],
+ "source": [
+ "# Supress jupyter warnings if required for cleaner output\n",
+ "import warnings\n",
+ "warnings.simplefilter('ignore')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fetching data, training a classifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
"source": [
"from sklearn.datasets import fetch_20newsgroups\n",
"newsgroups_train = fetch_20newsgroups(subset='train')\n",
@@ -56,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -73,7 +91,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -84,7 +102,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -93,7 +111,7 @@
"MultinomialNB(alpha=0.01, class_prior=None, fit_prior=True)"
]
},
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -106,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -115,7 +133,7 @@
"0.8350184193998174"
]
},
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -125,9 +143,16 @@
"sklearn.metrics.f1_score(newsgroups_test.target, pred, average='weighted')"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Explaining predictions using lime"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -137,7 +162,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -153,15 +178,6 @@
"print(c.predict_proba([newsgroups_test.data[0]]).round(3))"
]
},
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "explainer = LimeTextExplainer(class_names=class_names)"
- ]
- },
{
"cell_type": "code",
"execution_count": 10,
@@ -171,12 +187,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\n"
+ "\n"
]
}
],
"source": [
- "print(type(explainer))"
+ "limeexplainer = LimeTextExplainer(class_names=class_names)\n",
+ "print(type(limeexplainer))"
]
},
{
@@ -184,14 +201,6 @@
"execution_count": 11,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/anaconda3/envs/aix360/lib/python3.6/site-packages/lime/lime_text.py:116: FutureWarning: split() requires a non-empty pattern match.\n",
- " self.as_list = [s for s in splitter.split(self.raw) if s]\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
@@ -204,7 +213,8 @@
],
"source": [
"idx = 1340\n",
- "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, labels=[0, 17])\n",
+ "# aix360 style for explaining input instances\n",
+ "exp = limeexplainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, labels=[0, 17])\n",
"print('Document id: %d' % idx)\n",
"print('Predicted class =', class_names[nb.predict(test_vectors[idx]).reshape(1,-1)[0,0]])\n",
"print('True class: %s' % class_names[newsgroups_test.target[idx]])"
@@ -220,20 +230,20 @@
"output_type": "stream",
"text": [
"Explanation for class atheism\n",
- "('Caused', 0.2638134141747723)\n",
- "('Rice', 0.138346697297952)\n",
- "('Genocide', 0.12506034401934968)\n",
- "('owlnet', -0.08968526910905454)\n",
- "('scri', -0.08426182410065453)\n",
- "('Semitic', -0.07968947032986141)\n",
+ "('Caused', 0.2598306611703779)\n",
+ "('Rice', 0.1476407287363688)\n",
+ "('Genocide', 0.13182300286384235)\n",
+ "('scri', -0.09419412002335747)\n",
+ "('certainty', -0.09272741554383297)\n",
+ "('owlnet', -0.08993298975187172)\n",
"\n",
"Explanation for class mideast\n",
- "('fsu', -0.06012358222232414)\n",
- "('Theism', -0.05266356722531139)\n",
- "('Luther', -0.049997961626805004)\n",
- "('jews', 0.035879721280743154)\n",
- "('Caused', -0.03457976903151407)\n",
- "('PBS', 0.034166969785271686)\n"
+ "('fsu', -0.05535199329931831)\n",
+ "('Theism', -0.05150493402341905)\n",
+ "('Luther', -0.04742295494991691)\n",
+ "('jews', 0.03810985477960863)\n",
+ "('Caused', -0.037706845450166455)\n",
+ "('PBS', 0.031228586744514376)\n"
]
}
],
@@ -250,14 +260,6 @@
"execution_count": 13,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/anaconda3/envs/aix360/lib/python3.6/site-packages/lime/lime_text.py:116: FutureWarning: split() requires a non-empty pattern match.\n",
- " self.as_list = [s for s in splitter.split(self.raw) if s]\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
@@ -267,7 +269,8 @@
}
],
"source": [
- "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, top_labels=2)\n",
+ "# aix360 style for explaining input instances\n",
+ "exp = limeexplainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, top_labels=2)\n",
"print(exp.available_labels())"
]
},
@@ -37369,10 +37372,10 @@
"/***/ })\n",
"/******/ ]);\n",
"//# sourceMappingURL=bundle.js.map \n",
- " \n",
+ " \n",
" \n",
" \n",
@@ -74503,10 +74506,10 @@
"/***/ })\n",
"/******/ ]);\n",
"//# sourceMappingURL=bundle.js.map \n",
- " \n",
+ " \n",
" \n",
" \n",
"