-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path ML_app.py
131 lines (111 loc) · 5.99 KB
/
ML_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix, plot_roc_curve, plot_precision_recall_curve
from sklearn.metrics import precision_score, recall_score
# Silence Streamlit's deprecation warning for calling st.pyplot() without an
# explicit figure argument (i.e. rendering the global pyplot figure).
# NOTE(review): this config option may not exist in newer Streamlit releases,
# where setting an unknown option raises — confirm against the pinned version.
st.set_option('deprecation.showPyplotGlobalUse', False)
@st.cache(persist=True)
def _load_data():
    """Load ``mushrooms.csv`` and label-encode every column to integer codes.

    ``st.cache(persist=True)`` caches the returned DataFrame (persisted to
    disk), so if this function or its inputs remain unchanged the CSV is not
    re-read and re-encoded on every Streamlit rerun.

    Returns:
        pd.DataFrame: the dataset with each column replaced by the integer
        codes that ``LabelEncoder`` assigns to that column's values.
    """
    data = pd.read_csv('mushrooms.csv')
    encoder = LabelEncoder()
    for col in data.columns:
        # The same encoder is refit for each column, so codes are per-column.
        data[col] = encoder.fit_transform(data[col])
    return data


@st.cache(persist=True)
def _split(df):
    """Split the encoded dataset into a reproducible 70/30 train/test partition.

    The ``type`` column is the target; every other column is a feature.
    ``random_state=50`` keeps the split identical across reruns.

    Returns:
        tuple: ``(x_train, x_test, y_train, y_test)``.
    """
    y = df["type"]
    x = df.drop(columns=["type"])
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=50
    )
    return x_train, x_test, y_train, y_test


def _plot_metrics(metrics_list, model, x_test, y_test, class_names):
    """Render each requested diagnostic plot for a fitted classifier.

    Args:
        metrics_list: plot names chosen in the sidebar; any of
            'Confusion Matrix', 'ROC Curve', 'Precision-Recall Curve'.
        model: fitted estimator to evaluate.
        x_test: held-out features.
        y_test: held-out targets.
        class_names: axis labels for the confusion matrix.
    """
    # Each sklearn plot_* helper returns a Display object; its figure_ is
    # handed to st.pyplot explicitly instead of relying on the global figure.
    if 'Confusion Matrix' in metrics_list:
        st.subheader("Confusion Matrix")
        display = plot_confusion_matrix(
            model, x_test, y_test, display_labels=class_names
        )
        st.pyplot(display.figure_)
    if 'ROC Curve' in metrics_list:
        st.subheader("ROC Curve")
        display = plot_roc_curve(model, x_test, y_test)
        st.pyplot(display.figure_)
    if 'Precision-Recall Curve' in metrics_list:
        st.subheader("Precision-Recall Curve")
        display = plot_precision_recall_curve(model, x_test, y_test)
        st.pyplot(display.figure_)


def _fit_and_report(title, model, data_split, metrics, class_names):
    """Fit ``model``, write its test-set scores, then draw the chosen plots.

    Args:
        title: subheader shown above the results.
        model: unfitted scikit-learn classifier.
        data_split: ``(x_train, x_test, y_train, y_test)`` as returned by
            ``_split``.
        metrics: plot names to pass through to ``_plot_metrics``.
        class_names: display names for the encoded targets.
    """
    x_train, x_test, y_train, y_test = data_split
    st.subheader(title)
    model.fit(x_train, y_train)
    accuracy = model.score(x_test, y_test)
    y_pred = model.predict(x_test)
    # Built-in round() works for both NumPy scalars and plain Python floats;
    # some scikit-learn releases return the latter, which have no .round().
    st.write("Accuracy: ", round(accuracy, 2))
    # Default binary average scores positive label 1. No ``labels`` argument:
    # the targets are integer codes, so string class names would not match.
    st.write("Precision: ", round(precision_score(y_test, y_pred), 2))
    st.write("Recall: ", round(recall_score(y_test, y_pred), 2))
    _plot_metrics(metrics, model, x_test, y_test, class_names)


def main():
    """Streamlit entry point for the mushroom edible/poisonous classifier app.

    Draws the page, loads and splits the data, lets the user pick a classifier
    and its hyperparameters in the sidebar, and when "Classify" is pressed
    fits the model and shows its scores plus the selected diagnostic plots.
    """
    # Display names for encoded targets 0 and 1. Assumes LabelEncoder's sorted
    # ordering of the raw ``type`` values puts edible first — TODO confirm
    # against mushrooms.csv.
    class_names = ['edible', 'poisonous']
    metric_options = ('Confusion Matrix', 'Precision-Recall Curve', 'ROC Curve')

    st.title("Binary Classification Web App")
    st.sidebar.title("Binary Classifier Hyperparameter Tuning")
    st.markdown(
        """
<style>
.sidebar .sidebar-content {
background-image: linear-gradient(#2e7bcf,#2e7bcf);
color: yellow;
},
</style>
""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown("Are your mushrooms 🍄 edible or poisonous?")
    st.markdown("Are your mushrooms 🍄 edible or poisonous?")

    df = _load_data()
    data_split = _split(df)
    if st.sidebar.checkbox("Show raw data", False):
        st.subheader("Mushroom Data Set (Classification)")
        st.write(df)

    st.sidebar.subheader("Choose Classifier")
    classifier = st.sidebar.selectbox(
        "Classifier",
        ("Support Vector Machine (SVM)", "Logistic Regression", "Random Forest"),
    )

    if classifier == 'Support Vector Machine (SVM)':
        st.sidebar.subheader("Model Hyperparameters")
        C = st.sidebar.number_input(
            "C (Inverse regularization parameter)", 0.01, 10.0, step=0.01, key="C"
        )
        kernel = st.sidebar.radio("Kernel", ("rbf", "linear"), key="kernel")
        gamma = st.sidebar.radio(
            "Gamma (Kernel Coefficient)", ("scale", "auto"), key='gamma'
        )
        metrics = st.sidebar.multiselect(
            "What metrics would you like to plot?", metric_options
        )
        if st.sidebar.button("Classify", key="classify"):
            _fit_and_report(
                "Support Vector Machine (SVM) Results",
                SVC(C=C, kernel=kernel, gamma=gamma),
                data_split,
                metrics,
                class_names,
            )
    elif classifier == "Logistic Regression":
        st.sidebar.subheader("Model Hyperparameters")
        C = st.sidebar.number_input(
            "C (Inverse regularization parameter)", 0.01, 10.0, step=0.01, key="C_LR"
        )
        max_iter = st.sidebar.slider("Maximum Iterations", 100, 500, key="max_iter")
        metrics = st.sidebar.multiselect(
            "What metrics would you like to plot?", metric_options
        )
        if st.sidebar.button("Classify", key="classify"):
            _fit_and_report(
                "Logistic Regression Results",
                LogisticRegression(C=C, max_iter=max_iter),
                data_split,
                metrics,
                class_names,
            )
    elif classifier == 'Random Forest':
        st.sidebar.subheader("Model Hyperparameters")
        n_estimators = st.sidebar.number_input(
            "The number of decision trees", 100, 500, step=10, key="n_estimators"
        )
        max_depth = st.sidebar.number_input(
            "The maximum depth of the trees", 1, 20, step=1, key="max_depth"
        )
        # st.radio returns the selected option string. Convert it to a real
        # bool: the string 'False' is truthy, and newer scikit-learn rejects
        # non-bool values for ``bootstrap``.
        bootstrap = st.sidebar.radio(
            "Bootstrap samples when building trees", ('True', 'False'), key='bootstrap'
        ) == 'True'
        metrics = st.sidebar.multiselect(
            "What metrics would you like to plot?", metric_options
        )
        if st.sidebar.button("Classify", key="classify"):
            _fit_and_report(
                "Random Forest Results",
                RandomForestClassifier(
                    n_estimators=n_estimators,
                    max_depth=max_depth,
                    bootstrap=bootstrap,
                    n_jobs=-1,
                ),
                data_split,
                metrics,
                class_names,
            )
if __name__ == '__main__':
    # Start the app only when this file is executed as a script,
    # not when it is imported as a module.
    main()