-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
60 lines (48 loc) · 2.77 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import streamlit as st
import time
from model import predictXGBoost, predictRandomForest, get_data, plot_data, is_valid_ticker, trendFeatures, backTest
# Page configuration must run before any other Streamlit element is drawn.
st.set_page_config(
    page_title="Stock Market Prediction",
    page_icon=":chart_with_upwards_trend:",
    layout="wide",
)
st.title("Stock Market Prediction")
# Strip surrounding whitespace so a pasted " AAPL " is treated like "AAPL"
# and an all-blank entry counts as "nothing entered yet" rather than invalid.
ticker = st.text_input("Enter ticker symbol").strip()
def _show_direction(prediction):
    """Render a colored up/down verdict for a model's prediction.

    ``prediction`` is compared on a 0-100 scale: the XGBoost branch reports it
    as a percentage and the Random Forest branch yields 0 or 100, so 50 is the
    shared decision threshold.
    """
    if prediction > 50:
        st.markdown("Chances are the stock price will **:blue[increase]** tomorrow")
    elif prediction < 50:
        st.markdown("Chances are the stock price will **:red[decrease]** tomorrow")
    else:
        # An exact tie would otherwise leave the user with no verdict at all.
        st.markdown("The model sees no clear direction for tomorrow")


# Only act once a ticker has been entered. Validate it a single time per rerun
# (is_valid_ticker is presumably a remote lookup — confirm) and report an
# error if it is not valid.
if ticker:
    if not is_valid_ticker(ticker):
        st.error("Invalid or Delisted Ticker Symbol")
    else:
        st.divider()
        # Get the data set for the ticker symbol and show it as a table.
        st.write("Data Frame")
        data_set = get_data(ticker)
        st.dataframe(data_set, use_container_width=True)
        # Display the same data as a chart.
        fig = plot_data(data_set)
        st.plotly_chart(fig, use_container_width=True)
        st.divider()

        # Choose the ML prediction model used.
        model_choice = st.selectbox(
            "Select Prediction Model",
            ("XGBoost Regressor", "Random Forest Classifier"),
        )
        if st.button("Make Prediction"):
            # Keep the spinner visible while the real work runs (feature
            # engineering + model fitting), instead of an artificial delay.
            with st.spinner("Running model..."):
                # Create the trend features, then drop rows with NA values.
                predictors, data_set = trendFeatures(data_set)
                data_set = data_set.dropna()
                if model_choice == "XGBoost Regressor":
                    prediction = predictXGBoost(data_set, predictors)
                else:
                    # The selectbox only offers two options, so this is
                    # "Random Forest Classifier".
                    prediction, _ = predictRandomForest(data_set, predictors)

            # Display the data used by the ML model for prediction.
            st.dataframe(
                data_set.loc[:, ["Tomorrow", "Target"] + predictors],
                use_container_width=True,
            )
            if model_choice == "XGBoost Regressor":
                # `!s` keeps the original str(prediction) rendering.
                st.write(
                    f"According to the model, there is a {prediction!s}% "
                    f"chance that {ticker} will increase tomorrow"
                )
            # The Random Forest classifier predicts a class (0 or 100 here),
            # so no percentage sentence is shown for it — only the verdict.
            _show_direction(prediction)
            # NOTE: backTest(data_set, predictors) is imported but its output
            # is not currently displayed.