-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch_spaces.py
54 lines (50 loc) · 2.57 KB
/
search_spaces.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Project 2 - 5/4/23
# Joshua Adams, Weston Beebe, Parth Patel, Jonathan Sanderson, Samuel Sylvester
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
class ClfSearchSpace():
    """Parameter grids for GridSearchCV, looked up by short classifier name.

    Each grid is a class attribute; ``search_space`` maps a short name
    (``'DT'``, ``'RF'``, ...) to its grid. Every parameter key carries a
    ``clf__`` prefix -- presumably the estimator lives in a Pipeline step
    named ``clf``; confirm against the caller that builds the pipeline.

    The grids are shared class-level objects: ``get_search_space`` returns
    them directly (no copy), and ``'RF'`` and ``'ET'`` point at the very same
    dict, so mutating a returned grid affects every later lookup.
    """

    # parameter grids for GridSearchCV
    # Shotgun approach, one of these should work
    dt_space = {
        'clf__criterion': ['gini', 'entropy'],
        'clf__min_samples_split': [2, 5, 10],
        'clf__min_samples_leaf': [1, 2, 5],
        'clf__class_weight': [None, 'balanced'],
        'clf__max_depth': [None, 5, 10, 20, 50, 100],
    }

    rf_space = {
        'clf__n_estimators': [10, 50, 100],
        'clf__criterion': ['gini', 'entropy'],
        'clf__min_samples_split': [2, 5, 10],
        'clf__min_samples_leaf': [1, 2, 5],
        'clf__class_weight': ['balanced', 'balanced_subsample', None],
        'clf__bootstrap': [True, False],
    }

    ada_space = {
        # Only a default decision tree is searched as the base estimator.
        # Disabled alternative kept from the original file:
        #   RandomForestClassifier(n_estimators=10, criterion='gini',
        #                          class_weight='balanced', bootstrap=True)
        'clf__estimator': [DecisionTreeClassifier()],
        'clf__n_estimators': [10, 50, 100, 200],
        'clf__learning_rate': [0.1, 0.5, 0.9, 1.0],
    }

    svm_space = {
        'clf__C': [0.1, 0.5, 1, 5],
        'clf__kernel': ['rbf', 'linear'],
        # NOTE(review): if this targets sklearn's SVC, gamma has no effect
        # with the linear kernel, so those combinations repeat the same fit.
        'clf__gamma': ['scale', 'auto', 0.01, 0.001, 0.0001],
        'clf__class_weight': [None, 'balanced'],
    }

    kn_space = {
        'clf__n_neighbors': [3, 5, 7, 9, 11, 13, 15, 17, 19],
        # NOTE(review): if this targets KNeighborsClassifier with its default
        # p=2, 'minkowski' is the same distance as 'euclidean' -- confirm
        # whether the duplicate is intended.
        'clf__metric': ['euclidean', 'manhattan', 'chebyshev', 'minkowski'],
    }

    mlp_space = {
        'clf__activation': ['tanh', 'relu'],
        'clf__solver': ['sgd', 'adam'],
        'clf__alpha': [0.0001, 0.001, 0.01],
        'clf__hidden_layer_sizes': [
            (10,), (50,), (100,), (200,), (50, 50), (100, 100),
        ],
        'clf__learning_rate': ['constant', 'invscaling', 'adaptive'],
    }

    # A list of two sub-grids rather than one dict: shrinkage is only varied
    # together with the 'lsqr' and 'eigen' solvers, never with 'svd'.
    lda_space = [
        {'clf__solver': ['svd']},
        {
            'clf__solver': ['lsqr', 'eigen'],
            'clf__shrinkage': ['auto', None],
        },
    ]

    # lookup table
    search_space = {
        'DT': dt_space,
        'RF': rf_space,
        'ET': rf_space,  # reuses the RF grid (same dict object)
        'Ada': ada_space,
        'SVM': svm_space,
        'KN': kn_space,
        'MLP': mlp_space,
        'LDA': lda_space,
        'GNB': {},  # empty grid: no parameters are varied
    }

    def get_search_space(self, clf_name: str):
        """Return the parameter grid registered for *clf_name*.

        Args:
            clf_name: A key of ``search_space``, e.g. ``'DT'`` or ``'SVM'``.

        Returns:
            The grid stored under *clf_name*: a dict of parameter lists, or
            for ``'LDA'`` a list of such dicts. The shared class-level object
            is returned, not a copy.

        Raises:
            KeyError: If *clf_name* is not registered. The message lists the
                valid names.
        """
        try:
            return self.search_space[clf_name]
        except KeyError:
            # Same exception type as a plain dict lookup, so existing callers
            # that catch KeyError keep working; only the message is richer.
            valid = ', '.join(sorted(self.search_space))
            raise KeyError(
                f"unknown classifier name {clf_name!r}; "
                f"expected one of: {valid}"
            ) from None