
Commit 8f3e508

option to iteratively run random forest
1 parent d5dffd1 commit 8f3e508

File tree: 4 files changed, +146 -7 lines

2 binary files not shown.

machine_learning/feat_importance_plot.py

Lines changed: 29 additions & 3 deletions
@@ -137,12 +137,12 @@ def count_list_of_sublists(complex_list, feat_number):
 
     return number_feat_lst, percent_lst
 
-def plot_stacked_bar(df, x='Metadata_Time', ylabel="Percentage (%)", title="", colormap=None, rotation=45):
+def plot_stacked_bar(df, x='Metadata_Time', ylabel="Percentage (%)", title="", colormap=None, rotation=45, percentage_fontsize=18):
     ax = df.plot(x=x,
                  kind='bar',
                  stacked=True,
                  color=colormap)
-    plt.rcParams.update({'font.size': 18})
+    plt.rcParams.update({'font.size': percentage_fontsize})
     plt.legend(
         loc='center left',
         bbox_to_anchor=(1.0, 0.5),
@@ -158,7 +158,33 @@ def plot_stacked_bar(df, x='Metadata_Time', ylabel="Percentage (%)", title="", c
         if not height == 0.0:
             ax.text(x+width/2,
                     y+height/2,
-                    '{:.0f} %'.format(height),
+                    '{:.0f}%'.format(height),
+                    horizontalalignment='center',
+                    verticalalignment='center')
+    plt.show()
+
+def plot_stacked_bar_horizontal(df, x='Metadata_Time', ylabel="Percentage (%)", title="", colormap=None, rotation=45, percentage_fontsize=18):
+    ax = df.plot(x=x,
+                 kind='barh',
+                 stacked=True,
+                 color=colormap)
+    plt.rcParams.update({'font.size': percentage_fontsize})
+    plt.legend(
+        loc='center left',
+        bbox_to_anchor=(1.0, 0.5),
+        # reverse=True
+        )
+    plt.ylabel("")
+    plt.xlabel(ylabel, fontsize=15)
+    plt.xticks(rotation=rotation, fontsize=12)
+    plt.title(title, fontsize=15)
+    for p in ax.patches:
+        width, height = p.get_width(), p.get_height()
+        x, y = p.get_xy()
+        if not width == 0.0:
+            ax.text(x+width/2,
+                    y+height/2,
+                    '{:.0f}%'.format(width),
                     horizontalalignment='center',
                     verticalalignment='center')
     plt.show()
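
A quick usage sketch for the new horizontal variant (the toy DataFrame and its column names are invented for illustration and are not part of this commit):

import pandas as pd

# One row per timepoint; each row sums to 100%, and every non-zero bar
# segment is annotated with its '{:.0f}%' width label.
toy = pd.DataFrame({
    'Metadata_Time': ['0h', '24h', '48h'],
    'Texture':   [40, 55, 30],
    'Intensity': [35, 25, 45],
    'Shape':     [25, 20, 25],
})

plot_stacked_bar_horizontal(toy, x='Metadata_Time',
                            title="Feature classes over time",
                            percentage_fontsize=12)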

machine_learning/random_forest_utils.py

Lines changed: 117 additions & 4 deletions
@@ -2,12 +2,82 @@
 import shap
 import pandas as pd
 import matplotlib.pyplot as plt
+import numpy as np
 
 from sklearn.model_selection import train_test_split, cross_val_score
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+from sklearn.metrics import accuracy_score
 
-def random_forest_model_eval(df_, target = "", ccp = 0.08, n_estimators=10, max_depth=5, slice = False, column_slice = None, slice_to_value = None):
+
+def random_forest_iterations(df_, target="", ccp=0.08, n_estimators=10, max_depth=5,
+                             slice=False, column_slice=None, slice_to_value=None, number_iterations=None):
+    """
+    This function takes a df (which you can slice based on a column and a value), repeatedly creates a train/test split, and trains a Random Forest Classifier on each split.
+    Feature importances are collected per iteration and averaged; mean train/test accuracies are printed at the end.
+    *df_ (DataFrame): dataframe that contains X and y
+    *target (str): column name that will be the target (what are the classes?)
+    *slice (bool): True if you want to slice the df based on a column_slice (str) and a value that is in that column (slice_to_value)
+    *number_iterations (int): how many train/evaluate cycles to run
+
+    return:
+    X: the feature matrix the models were trained on;
+    all_feature_importances: the per-iteration feature importances;
+    aggregate_feature_importances: the mean importance of each feature across iterations;
+    forest: the classifier from the last iteration.
+    """
+    if slice:
+        df = df_[df_[column_slice] == slice_to_value]
+        print(f"Looping through {column_slice} = {slice_to_value}.")
+    else:
+        df = df_.copy()
+
+    # features and metadata lists of cols
+    feat = pycytominer.cyto_utils.features.infer_cp_features(df, metadata=False)
+    meta = pycytominer.cyto_utils.features.infer_cp_features(df, metadata=True)
+    # X, y and y target
+    X = pd.DataFrame(df, columns=feat)
+    y = pd.DataFrame(df, columns=meta)
+    y_target = y[target]
+
+    all_feature_importances = []
+    all_train_accuracies = []
+    all_test_accuracies = []
+    for i in range(number_iterations):
+        # vary the split seed per iteration; with a fixed random_state every
+        # run would be identical and the across-iteration SD would be zero
+        X_train, X_test, y_train, y_test = train_test_split(X, y_target, random_state=i)
+        # train
+        forest = RandomForestClassifier(random_state=0, ccp_alpha=ccp, n_estimators=n_estimators, max_depth=max_depth)
+        forest.fit(X_train, y_train)
+        # Get feature importances for this iteration
+        iteration_feature_importances = forest.feature_importances_
+
+        # Store the feature importances
+        all_feature_importances.append(iteration_feature_importances)
+
+        # Predictions on training set
+        y_train_pred = forest.predict(X_train)
+        train_accuracy = accuracy_score(y_train, y_train_pred)
+        all_train_accuracies.append(train_accuracy)
+
+        # Predictions on test set
+        y_test_pred = forest.predict(X_test)
+        test_accuracy = accuracy_score(y_test, y_test_pred)
+        all_test_accuracies.append(test_accuracy)
+
+    # Aggregate feature importances across iterations
+    aggregate_feature_importances = np.mean(all_feature_importances, axis=0)
+
+    # Rank features based on aggregated importances
+    feature_ranking = np.argsort(aggregate_feature_importances)[::-1]
+
+    # Calculate mean accuracy for training and testing sets
+    mean_train_accuracy = np.mean(all_train_accuracies)
+    mean_test_accuracy = np.mean(all_test_accuracies)
+
+    print(f"\nMean Training Accuracy: {mean_train_accuracy}")
+    print(f"Mean Testing Accuracy: {mean_test_accuracy}")
+
+    return X, all_feature_importances, aggregate_feature_importances, forest
+
+
+def random_forest_model_eval(df_, target="", ccp=0.08, n_estimators=10, max_depth=5,
+                             slice=False, column_slice=None, slice_to_value=None,
+                             ):
     """
     This function takes a df (which you can slice based on a column and a value), creates a train/test split, and trains a Random Forest Classifier model.
     We evaluate the model using confusion matrix, cross validation, and shap.
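
A hypothetical call of the new function, assuming `profiles` is a pycytominer-style DataFrame with Metadata_* columns plus CellProfiler feature columns (the column names and timepoint value below are invented for illustration):

X, all_imps, mean_imps, forest = random_forest_iterations(
    profiles,
    target="Metadata_Compound",   # classes the forest predicts
    slice=True,
    column_slice="Metadata_Time",
    slice_to_value="24h",         # keep only the 24 h timepoint
    number_iterations=100,
)
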
@@ -67,9 +137,9 @@ def random_forest_model_eval(df_, target = "", ccp = 0.08, n_estimators=10, max_
     row_index=2
     base = explainer.expected_value
     shap.multioutput_decision_plot(list(base), shap_values_t,
-                                                row_index=row_index,
-                                                feature_names=X_train.columns.to_list(),
-                                                )
+                                   row_index=row_index,
+                                   feature_names=X_train.columns.to_list(),
+                                   )
 
     return shap_values_t, X_train, forest

@@ -134,6 +204,49 @@ def loop_random_forest_model_eval(df_, target = "", column_to_loop = "", list_to
                                    feature_names=X_train.columns.to_list(),
                                    )
 
+def col_generator(df, cols_to_join=['Metadata_Compound', 'Metadata_Concentration']):
+    """
+    Create a new column containing information from compound + concentration of compounds.
+    *cols_to_join: column names to join on; order is determined by their order in this list
+    """
+    col_copy = cols_to_join.copy()
+    cols_to_join = cols_to_join.copy()  # work on a copy so pop() below doesn't mutate the caller's list (or the mutable default)
+    init = cols_to_join.pop(0)  # pop the first element of the list
+    new_col_temp = [init]  # keep the first element in the list
+    for cols in cols_to_join:
+        temp = cols.split("_", 1)  # only split the Metadata_ prefix out
+        print(temp[1])
+        new_col_temp.append(temp[1])
+    new_col = '_'.join(new_col_temp)  # generate the new column name from the list
+    df[new_col] = df[col_copy].astype(str).agg(' '.join, axis=1)  # cast to str and create the new metadata column
+    print("Names of the compounds + concentration: ", df[new_col].unique())
+
+    return df, new_col
+
+def most_important_with_sd(X, all_feature_importances, aggregate_feature_importances, number_of_features_select=30, compound=None):
+    """
+    Select the top number_of_features_select features by mean importance and plot them
+    as a horizontal bar chart with the across-iteration standard deviation as error bars.
+    """
+    feat_importance_sd = np.std(all_feature_importances, axis=0)
+    feature_ranking = np.argsort(aggregate_feature_importances)[::-1]
+    features = []
+    importance = []
+    importance_sd = []
+    for i, feature_index in enumerate(feature_ranking):
+        if i < number_of_features_select:
+            features.append(X.columns[feature_index])
+            importance.append(aggregate_feature_importances[feature_index])
+            importance_sd.append(feat_importance_sd[feature_index])
+
+    df_results = pd.DataFrame(list(zip(features, importance, importance_sd)),
+                              columns=[f'{compound}_features', f'{compound}_importance', f'{compound}_importance_sd'])
+
+    fig, ax = plt.subplots()
+    df_results.sort_values(by=[f'{compound}_importance'], ascending=True).plot.barh(
+        x=f'{compound}_features', y=f'{compound}_importance',
+        yerr=f'{compound}_importance_sd', ax=ax, align="center")
+    ax.set_title(f"{compound} feature importances ({len(all_feature_importances)} iterations)")
+    ax.set_xlabel("Feature importance")
+    fig.set_size_inches(18.5, 10.5)
+    plt.show()
+
+    return df_results
+
 ############################ TO FIX LATER
 
 def pruning():
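
A hypothetical end-to-end sketch tying the new helpers together (`profiles` and the compound name are assumptions, not part of this commit):

# Build a combined "Metadata_Compound_Concentration" target column
profiles, target_col = col_generator(
    profiles, cols_to_join=['Metadata_Compound', 'Metadata_Concentration'])

# Run the iterative random forest on that target
X, all_imps, mean_imps, forest = random_forest_iterations(
    profiles, target=target_col, number_iterations=100)

# Bar chart of the top 30 features, with across-iteration SD as error bars
top_feats = most_important_with_sd(
    X, all_imps, mean_imps, number_of_features_select=30, compound="CompoundA")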
