From fd849096b7fbbbc7053ddde5cd060d8f8373cdb7 Mon Sep 17 00:00:00 2001
From: Mutra12 <39122016+Mutra12@users.noreply.github.com>
Date: Sat, 19 May 2018 15:14:50 +0300
Subject: [PATCH] Optimized find_best_split for better run time

Updated the README description of find_best_split accordingly.
---
 README           |  2 +-
 decisionTree.cpp | 36 ++++++++++++++++++++----------------
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/README b/README
index a253c1a..3a78350 100644
--- a/README
+++ b/README
@@ -32,7 +32,7 @@ Functia RandomForest::predict va apela Node::predict pentru fiecare arbore de de
 
 The get_random_samples function returns a submatrix of the initial data set, with num_to_return randomly generated rows, all distinct from one another.
 
-The find_best_split function returns a pair (splitIndex, splitValue) representing the best split. For each samples[i], the highest Information Gain is computed by calculating the entropy of the parent and of the left and right children.
+The find_best_split function returns a pair (splitIndex, splitValue) representing the best split. For each samples[i], the highest Information Gain is computed by calculating the entropy of the parent and of the left and right children. As an optimization, the candidate values of each feature are averaged and only that average is evaluated as a split value.
 
 The train function checks whether the samples at the current node all share the same class and, if so, turns the node into a leaf. Otherwise, if no good split is found, the current node also becomes a leaf; if one is found, the best split is taken and the function is called recursively for the left and right children.
 
diff --git a/decisionTree.cpp b/decisionTree.cpp
index 6bc5db4..2b95521 100644
--- a/decisionTree.cpp
+++ b/decisionTree.cpp
@@ -69,28 +69,32 @@ pair<int, int> find_best_split(const vector<vector<int>> &samples,
     int splitIndex = -1, splitValue = -1;
     float max = -1;
     float ig = 0;
-    for (int i = 1; i < samples[0].size(); ++i){
-        float H_parent = get_entropy_by_indexes(samples,
+    float H_parent = get_entropy_by_indexes(samples,
             dimensions);
+    for (int i = 1; i < samples[0].size(); ++i){
         unique = compute_unique(samples, i);
+        float avg, sum = 0, k = 0;
         for (int j = 0; j < unique.size(); ++j){
             if (unique[j] < 5 || unique[j] > 250){
-                children = get_split_as_indexes(samples, i, unique[j]);
-                float H_left = get_entropy_by_indexes(samples, children.first);
-                float H_right = get_entropy_by_indexes
-                    (samples, children.second);
-                ig = H_parent - (children.first.size() * H_left +
-                    children.second.size()
-                    * H_right)
-                    /(children.first.size() +
-                    children.second.size());
-                if (ig > max){
-                    max = ig;
-                    splitIndex = i;
-                    splitValue = unique[j];
-                }
+                sum += unique[j];
+                k++;
             }
         }
+        avg = sum / k;
+        children = get_split_as_indexes(samples, i, avg);
+        float H_left = get_entropy_by_indexes(samples, children.first);
+        float H_right = get_entropy_by_indexes
+            (samples, children.second);
+        ig = H_parent - (children.first.size() * H_left +
+            children.second.size()
+            * H_right)
+            /(children.first.size() +
+            children.second.size());
+        if (ig > max){
+            max = ig;
+            splitIndex = i;
+            splitValue = avg;
+        }
     }
     return pair<int, int>(splitIndex, splitValue);
 }
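For reference, here is a minimal, self-contained sketch of the averaging heuristic this patch introduces: the parent entropy is computed once, and for each feature the candidate values outside [5, 250] are averaged so that only one threshold per feature is scored by information gain. The sketch assumes the column-0-is-label row layout and uses its own toy entropy() helper and find_best_split_avg() name instead of the repository's get_entropy_by_indexes, get_split_as_indexes, and compute_unique helpers; the k == 0 and empty-child guards are additions of the sketch, not behaviour of decisionTree.cpp.

// Sketch of averaged-threshold split selection (assumptions noted above).
#include <cmath>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

using std::vector;

// Shannon entropy of the class labels (column 0) of the given row indexes.
static float entropy(const vector<vector<int>> &samples, const vector<int> &rows) {
    std::map<int, int> counts;
    for (int r : rows) counts[samples[r][0]]++;
    float h = 0.0f;
    for (const auto &kv : counts) {
        float p = static_cast<float>(kv.second) / rows.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Returns (splitIndex, splitValue) chosen with the averaged-threshold heuristic.
std::pair<int, int> find_best_split_avg(const vector<vector<int>> &samples) {
    vector<int> all(samples.size());
    for (size_t r = 0; r < samples.size(); ++r) all[r] = static_cast<int>(r);
    float h_parent = entropy(samples, all);   // computed once, outside the feature loop

    int split_index = -1, split_value = -1;
    float best_gain = -1.0f;
    for (size_t i = 1; i < samples[0].size(); ++i) {
        // Average the "extreme" unique values of feature i (same < 5 / > 250 filter as the patch).
        std::set<int> unique;
        for (const auto &row : samples) unique.insert(row[i]);
        float sum = 0.0f, k = 0.0f;
        for (int v : unique)
            if (v < 5 || v > 250) { sum += v; k++; }
        if (k == 0) continue;                 // guard added here: no candidate for this feature
        int avg = static_cast<int>(sum / k);

        // Score the single averaged threshold by information gain.
        vector<int> left, right;
        for (size_t r = 0; r < samples.size(); ++r)
            (samples[r][i] <= avg ? left : right).push_back(static_cast<int>(r));
        if (left.empty() || right.empty()) continue;
        float gain = h_parent -
                     (left.size() * entropy(samples, left) +
                      right.size() * entropy(samples, right)) /
                     (left.size() + right.size());
        if (gain > best_gain) {
            best_gain = gain;
            split_index = static_cast<int>(i);
            split_value = avg;
        }
    }
    return {split_index, split_value};
}

int main() {
    // Toy data: column 0 is the class label, columns 1..n are pixel-like feature values.
    vector<vector<int>> samples = {
        {0, 1, 200}, {0, 3, 251}, {1, 254, 2}, {1, 252, 4},
    };
    std::pair<int, int> best = find_best_split_avg(samples);
    std::cout << "splitIndex=" << best.first << " splitValue=" << best.second << '\n';
    return 0;
}

Hoisting the parent-entropy call out of the feature loop and scoring one averaged threshold per feature cuts the entropy evaluations from one pair per unique candidate value to one pair per feature, which is where the run-time improvement comes from; the trade-off is that the averaged threshold is a heuristic and may miss the split a full search over the unique values would find.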