Skip to content

Commit

Permalink
Optimized find_best_split for better run time
Browse files Browse the repository at this point in the history
Changed README
  • Loading branch information
andreicalin-georgescu authored May 19, 2018
1 parent 2adcd86 commit fd84909
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Functia RandomForest::predict va apela Node::predict pentru fiecare arbore de de

Functia get_random_samples va intoarce o submatrice din setul initial de date, cu num_to_return linii generate aleator, fiind unice intre ele.

Functia find_best_split intoarce o pereche (splitIndex, splitValue) care va reprezenta cel mai bun split. Pentru fiecare samples[i] se va calcula cel mai mare Information Gain, calculand entropia pentru parinte si copilul stang, respectiv drept.
Functia find_best_split intoarce o pereche (splitIndex, splitValue) care va reprezenta cel mai bun split. Pentru fiecare samples[i] se va calcula cel mai mare Information Gain, calculand entropia pentru parinte si copilul stang, respectiv drept. Pentru optimizare s-a ales o medie a valorilor.

Functia train verifica daca nodurile din samples au aceeasi clasa si transforma nodul curent in frunza. In caz contrar, daca nu se va gasi un split bun, nodul curent devine frunza. Altfel se va lua cel mai bun split si se va continua apelarea recursiva a functiei pentru copilul stang, respectiv drept.

Expand Down
36 changes: 20 additions & 16 deletions decisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,32 @@ pair<int, int> find_best_split(const vector<vector<int>> &samples,
int splitIndex = -1, splitValue = -1;
float max = -1;
float ig = 0;
for (int i = 1; i < samples[0].size(); ++i){
float H_parent = get_entropy_by_indexes(samples,
float H_parent = get_entropy_by_indexes(samples,
dimensions);
for (int i = 1; i < samples[0].size(); ++i){
unique = compute_unique(samples, i);
float avg, sum = 0, k = 0;
for (int j = 0; j < unique.size(); ++j){
if (unique[j] < 5 || unique[j] > 250){
children = get_split_as_indexes(samples, i, unique[j]);
float H_left = get_entropy_by_indexes(samples, children.first);
float H_right = get_entropy_by_indexes
(samples, children.second);
ig = H_parent - (children.first.size() * H_left +
children.second.size()
* H_right)
/(children.first.size() +
children.second.size());
if (ig > max){
max = ig;
splitIndex = i;
splitValue = unique[j];
}
sum += unique[j];
k++;
}
}
avg = sum / k;
children = get_split_as_indexes(samples, i, avg);
float H_left = get_entropy_by_indexes(samples, children.first);
float H_right = get_entropy_by_indexes
(samples, children.second);
ig = H_parent - (children.first.size() * H_left +
children.second.size()
* H_right)
/(children.first.size() +
children.second.size());
if (ig > max){
max = ig;
splitIndex = i;
splitValue = avg;
}
}
return pair<int, int>(splitIndex, splitValue);
}
Expand Down

0 comments on commit fd84909

Please sign in to comment.