Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polished code, added a dataset, improved the readability #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 66 additions & 48 deletions decision tree classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {
"scrolled": false
},
Expand Down Expand Up @@ -66,99 +66,99 @@
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>5.4</td>\n",
" <td>3.9</td>\n",
" <td>1.7</td>\n",
" <td>0.4</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>4.6</td>\n",
" <td>3.4</td>\n",
" <td>1.4</td>\n",
" <td>0.3</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>5.0</td>\n",
" <td>3.4</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>4.4</td>\n",
" <td>2.9</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>4.9</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.1</td>\n",
" <td>0</td>\n",
" <td>Setosa</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal_length sepal_width petal_length petal_width type\n",
"0 5.1 3.5 1.4 0.2 0\n",
"1 4.9 3.0 1.4 0.2 0\n",
"2 4.7 3.2 1.3 0.2 0\n",
"3 4.6 3.1 1.5 0.2 0\n",
"4 5.0 3.6 1.4 0.2 0\n",
"5 5.4 3.9 1.7 0.4 0\n",
"6 4.6 3.4 1.4 0.3 0\n",
"7 5.0 3.4 1.5 0.2 0\n",
"8 4.4 2.9 1.4 0.2 0\n",
"9 4.9 3.1 1.5 0.1 0"
" sepal_length sepal_width petal_length petal_width type\n",
"0 5.1 3.5 1.4 0.2 Setosa\n",
"1 4.9 3.0 1.4 0.2 Setosa\n",
"2 4.7 3.2 1.3 0.2 Setosa\n",
"3 4.6 3.1 1.5 0.2 Setosa\n",
"4 5.0 3.6 1.4 0.2 Setosa\n",
"5 5.4 3.9 1.7 0.4 Setosa\n",
"6 4.6 3.4 1.4 0.3 Setosa\n",
"7 5.0 3.4 1.5 0.2 Setosa\n",
"8 4.4 2.9 1.4 0.2 Setosa\n",
"9 4.9 3.1 1.5 0.1 Setosa"
]
},
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -178,7 +178,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -206,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -221,6 +221,7 @@
" self.min_samples_split = min_samples_split\n",
" self.max_depth = max_depth\n",
" \n",
" \n",
" def build_tree(self, dataset, curr_depth=0):\n",
" ''' recursive function to build the tree ''' \n",
" \n",
Expand All @@ -229,14 +230,17 @@
" \n",
" # split until stopping conditions are met\n",
" if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n",
" \n",
" # find the best split\n",
" best_split = self.get_best_split(dataset, num_samples, num_features)\n",
" best_split = self.get_best_split(dataset, num_features)\n",
" \n",
" # check if information gain is positive\n",
" if best_split[\"info_gain\"]>0:\n",
" # recur left\n",
" left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n",
" # recur right\n",
" right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n",
" \n",
" # return decision node\n",
" return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n",
" left_subtree, right_subtree, best_split[\"info_gain\"])\n",
Expand All @@ -246,7 +250,8 @@
" # return leaf node\n",
" return Node(value=leaf_value)\n",
" \n",
" def get_best_split(self, dataset, num_samples, num_features):\n",
" \n",
" def get_best_split(self, dataset, num_features):\n",
" ''' function to find the best split '''\n",
" \n",
" # dictionary to store the best split\n",
Expand All @@ -257,16 +262,19 @@
" for feature_index in range(num_features):\n",
" feature_values = dataset[:, feature_index]\n",
" possible_thresholds = np.unique(feature_values)\n",
" \n",
" # loop over all the feature values present in the data\n",
" for threshold in possible_thresholds:\n",
" # get current split\n",
" dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n",
" \n",
" # check if childs are not null\n",
" if len(dataset_left)>0 and len(dataset_right)>0:\n",
" y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n",
" # compute information gain\n",
" curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n",
" # update the best split if needed\n",
"\n",
" # update the best split if needed \n",
" if curr_info_gain>max_info_gain:\n",
" best_split[\"feature_index\"] = feature_index\n",
" best_split[\"threshold\"] = threshold\n",
Expand All @@ -278,24 +286,28 @@
" # return best split\n",
" return best_split\n",
" \n",
" \n",
" def split(self, dataset, feature_index, threshold):\n",
" ''' function to split the data '''\n",
" \n",
" dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n",
" dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n",
" return dataset_left, dataset_right\n",
" \n",
" \n",
" def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n",
" ''' function to compute information gain '''\n",
" \n",
" weight_l = len(l_child) / len(parent)\n",
" weight_r = len(r_child) / len(parent)\n",
" \n",
" if mode==\"gini\":\n",
" gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n",
" else:\n",
" gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n",
" return gain\n",
" \n",
" \n",
" def entropy(self, y):\n",
" ''' function to compute entropy '''\n",
" \n",
Expand All @@ -306,6 +318,7 @@
" entropy += -p_cls * np.log2(p_cls)\n",
" return entropy\n",
" \n",
" \n",
" def gini_index(self, y):\n",
" ''' function to compute gini index '''\n",
" \n",
Expand All @@ -316,12 +329,14 @@
" gini += p_cls**2\n",
" return 1 - gini\n",
" \n",
" \n",
" def calculate_leaf_value(self, Y):\n",
" ''' function to compute leaf node '''\n",
" \n",
" Y = list(Y)\n",
" return max(Y, key=Y.count)\n",
" \n",
" \n",
" def print_tree(self, tree=None, indent=\" \"):\n",
" ''' function to print the tree '''\n",
" \n",
Expand All @@ -332,30 +347,33 @@
" print(tree.value)\n",
"\n",
" else:\n",
" print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.info_gain)\n",
" print(f'X_{str(tree.feature_index)} <= {tree.threshold} ? {tree.info_gain}')\n",
" print(\"%sleft:\" % (indent), end=\"\")\n",
" self.print_tree(tree.left, indent + indent)\n",
" print(\"%sright:\" % (indent), end=\"\")\n",
" self.print_tree(tree.right, indent + indent)\n",
" \n",
" \n",
" def fit(self, X, Y):\n",
" ''' function to train the tree '''\n",
" \n",
" dataset = np.concatenate((X, Y), axis=1)\n",
" self.root = self.build_tree(dataset)\n",
" \n",
" \n",
" def predict(self, X):\n",
" ''' function to predict new dataset '''\n",
" \n",
" preditions = [self.make_prediction(x, self.root) for x in X]\n",
" return preditions\n",
" \n",
" \n",
" def make_prediction(self, x, tree):\n",
" ''' function to predict a single data point '''\n",
" \n",
" if tree.value!=None: return tree.value\n",
" if tree.value is not None: return tree.value\n",
" feature_val = x[tree.feature_index]\n",
" if feature_val<=tree.threshold:\n",
" if feature_val <= tree.threshold:\n",
" return self.make_prediction(x, tree.left)\n",
" else:\n",
" return self.make_prediction(x, tree.right)"
Expand All @@ -370,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -389,24 +407,24 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_2 <= 1.9 ? 0.33741385372714494\n",
" left:0.0\n",
" right:X_3 <= 1.5 ? 0.427106638180289\n",
" left:X_2 <= 4.9 ? 0.05124653739612173\n",
" left:1.0\n",
" right:2.0\n",
" right:X_2 <= 5.0 ? 0.019631171921475288\n",
" left:X_1 <= 2.8 ? 0.20833333333333334\n",
" left:2.0\n",
" right:1.0\n",
" right:2.0\n"
"X_2 <= 1.9 ? 0.33741385372714494\n",
" left:Setosa\n",
" right:X_3 <= 1.5 ? 0.427106638180289\n",
" left:X_2 <= 4.9 ? 0.05124653739612173\n",
" left:Versicolor\n",
" right:Virginica\n",
" right:X_2 <= 5.0 ? 0.019631171921475288\n",
" left:X_1 <= 2.8 ? 0.20833333333333334\n",
" left:Virginica\n",
" right:Versicolor\n",
" right:Virginica\n"
]
}
],
Expand All @@ -425,7 +443,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -434,7 +452,7 @@
"0.9333333333333333"
]
},
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -462,7 +480,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading